Rishitha3 committed on
Commit
fe0a12e
·
verified ·
1 Parent(s): d1decfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -65
app.py CHANGED
@@ -1,70 +1,182 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
  )
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
68
 
69
- if __name__ == "__main__":
70
- demo.launch()
 
1
+ import os
2
  import gradio as gr
3
+ import fitz # PyMuPDF for PDFs
4
+ import docx
5
+ import faiss
6
+ import numpy as np
7
+ from sentence_transformers import SentenceTransformer
8
+ from transformers import pipeline
9
+ from gtts import gTTS # βœ… gTTS for speech
10
+
11
# =============================
# 1) Auth & Config
# =============================
# The Hugging Face token is read from the environment and handed to both
# pipelines below; start-up aborts when it is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("⚠️ Please set your HF_TOKEN as an environment variable.")

# Hub model identifiers, one per stage of the app.
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_ID = "meta-llama/Llama-3.2-1b-instruct"  # can be swapped for a smaller model for more speed
ASR_MODEL_ID = "openai/whisper-small"

# =============================
# 2) Load Models
# =============================
# Sentence embeddings used for both document chunks and queries.
embedding_model = SentenceTransformer(EMBED_MODEL_ID)

# LLM that writes the final answer (no HyDE step).
qa_model = pipeline(
    task="text-generation",
    model=LLM_MODEL_ID,
    token=HF_TOKEN,
    device_map="auto",
)

# Speech-to-text model for the voice input.
stt_model = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_ID,
    token=HF_TOKEN,
)
42
+
43
+ # =============================
44
+ # 3) File Text Extraction
45
+ # =============================
46
def extract_text(file_path: str) -> str:
    """Return the plain text of a PDF, DOCX, or text-like file.

    The reader is chosen from the file extension, compared
    case-insensitively: ``.pdf`` is read page by page with PyMuPDF,
    ``.docx`` paragraph by paragraph with python-docx (one newline appended
    per paragraph), and anything else is read as bytes and decoded as UTF-8
    with undecodable bytes dropped.

    Args:
        file_path: Path of the file to read. A falsy value yields ``""``.

    Returns:
        The extracted text, or ``""`` when no path is given.
    """
    if not file_path:
        return ""

    ext = os.path.splitext(file_path)[1].lower()

    if ext == ".pdf":
        # The context manager closes the document even when a page fails to
        # parse, so the file handle is not leaked.
        with fitz.open(file_path) as pdf:
            return "".join(page.get_text("text") for page in pdf)

    if ext == ".docx":
        paragraphs = docx.Document(file_path).paragraphs
        return "".join(f"{para.text}\n" for para in paragraphs)

    # Fallback: treat everything else as UTF-8 text, ignoring bad bytes.
    with open(file_path, "rb") as handle:
        return handle.read().decode("utf-8", errors="ignore")
63
+
64
+ # =============================
65
+ # 4) Build FAISS Index
66
+ # =============================
67
def build_faiss(text: str, chunk_size=500, overlap=50):
    """Split ``text`` into overlapping chunks and index their embeddings.

    Chunks are character windows of ``chunk_size`` taken every
    ``chunk_size - overlap`` characters (at least 1); whitespace-only
    windows are dropped. The chunks are embedded with ``embedding_model``
    (normalized vectors) and stored in an inner-product FAISS index.

    Args:
        text: Full document text.
        chunk_size: Length of each chunk in characters.
        overlap: Number of characters shared by consecutive chunks.

    Returns:
        ``(index, chunks)``, or ``(None, None)`` when the text is blank or
        yields no usable chunk.
    """
    if text.strip() == "":
        return None, None

    stride = max(1, chunk_size - overlap)
    windows = (text[start:start + chunk_size] for start in range(0, len(text), stride))
    chunks = [window for window in windows if window.strip()]
    if len(chunks) == 0:
        return None, None

    vectors = embedding_model.encode(
        chunks,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index, chunks
86
+
87
# =============================
# 5) Globals
# =============================
# Index and chunk list of the currently loaded document; both stay None
# until upload_file() has indexed a file.
doc_index = doc_chunks = None
92
+
93
+ # =============================
94
+ # 6) Handlers
95
+ # =============================
96
def upload_file(file_path: str):
    """Extract, chunk, and index the uploaded file; return a status message.

    On success the module-level ``doc_index`` and ``doc_chunks`` are
    replaced with the new index and chunk list. They are left untouched
    when no path is given or the file yields no indexable text.

    Args:
        file_path: Path of the uploaded file; falsy when nothing was uploaded.

    Returns:
        A human-readable status string for the UI.
    """
    global doc_index, doc_chunks

    if not file_path:
        return "⚠️ Please upload a file first."

    index, chunks = build_faiss(extract_text(file_path))
    if index is None:
        return "⚠️ Could not index: file appears empty."

    doc_index, doc_chunks = index, chunks
    # The leading check mark was mis-encoded in the original message.
    return f"✅ Document indexed! {len(chunks)} chunks ready."
106
+
107
def answer_query(query: str):
    """Answer ``query`` from the indexed document using retrieval + the LLM.

    The query is embedded, up to five nearest chunks are fetched from
    ``doc_index``, and the LLM is prompted to answer from that context only.

    Args:
        query: The user's question.

    Returns:
        The generated answer, or a warning string when the query is blank or
        no document has been indexed yet.
    """
    if not query or not query.strip():
        return "⚠️ Please enter a question."
    if doc_index is None or not doc_chunks:
        return "⚠️ Please upload and index a document first."

    # Embed the query directly (no HyDE) and retrieve the closest chunks.
    q_vec = embedding_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _, neighbor_ids = doc_index.search(q_vec, k=min(5, len(doc_chunks)))
    # Skip ids outside the chunk list (e.g. negative placeholders).
    retrieved = [doc_chunks[i] for i in neighbor_ids[0] if 0 <= i < len(doc_chunks)]

    context = "\n\n".join(retrieved)
    final_prompt = (
        "You are a helpful assistant. Answer based only on the context. "
        "If the answer is not in the context, say you don't know.\n\n"
        f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    )
    # return_full_text=False keeps only the newly generated continuation;
    # by default the pipeline echoes the whole prompt (context included)
    # in "generated_text", which would be shown and spoken as the answer.
    generations = qa_model(
        final_prompt,
        max_new_tokens=200,
        do_sample=False,
        return_full_text=False,
    )
    return generations[0]["generated_text"].strip()
127
+
128
def synthesize_with_gtts(text: str, out_path="out.mp3"):
    """Speak ``text`` in English with gTTS, save it as an mp3, and return the path.

    Args:
        text: Text to convert to speech.
        out_path: Destination file for the mp3.

    Returns:
        ``out_path``, unchanged.
    """
    gTTS(text=text, lang="en").save(out_path)
    return out_path
133
+
134
def voice_query(audio_path: str):
    """Transcribe a spoken question, answer it, and voice the answer.

    Args:
        audio_path: Path of the recorded audio; falsy when nothing was recorded.

    Returns:
        A ``(recognized_text, answer_text, mp3_path)`` triple. When there is
        no recording or the transcript is empty, the first item carries a
        warning, the answer is ``""`` and the audio path is ``None``.
    """
    if not audio_path:
        return "⚠️ Please record your question.", "", None

    # Step 1: speech -> text
    transcript = stt_model(audio_path).get("text", "").strip()
    if transcript == "":
        return "⚠️ Could not transcribe audio.", "", None

    # Step 2: retrieval-augmented answer
    reply = answer_query(transcript)

    # Step 3: text -> speech (gTTS writes an mp3 file)
    reply_audio = synthesize_with_gtts(reply, "answer.mp3")

    return transcript, reply, reply_audio
151
+
152
# =============================
# 7) Gradio UI
# =============================
# NOTE(review): nesting below is reconstructed from a flattened diff; the
# voice widgets are assumed to live in the right-hand column - confirm.
# Emoji in the labels were mis-encoded in the source and are restored here.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="cyan")) as demo:
    gr.Markdown("# 📚 Simple RAG Chatbot + 🎤 Voice")
    gr.Markdown("Upload a PDF/DOCX/TXT and ask by typing **or** speaking. Uses Whisper for ASR and gTTS for speech output.")

    with gr.Row():
        # Left column: document upload and indexing status.
        with gr.Column(scale=1):
            file_input = gr.File(label="📂 Upload Document", type="filepath")
            upload_btn = gr.Button("⚡ Index Document", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

        # Right column: typed and spoken questions.
        with gr.Column(scale=2):
            gr.Markdown("### ✍️ Text Chat")
            query = gr.Textbox(label="❓ Ask a Question", placeholder="e.g., What are the key points?")
            ask_btn = gr.Button("🚀 Get Answer", variant="primary")
            answer = gr.Textbox(label="💡 Answer", lines=8)

            gr.Markdown("### 🎤 Voice Chat")
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Speak your question")
            rec_text = gr.Textbox(label="📝 Recognized Speech", interactive=False)
            v_answer = gr.Textbox(label="💡 Answer (from voice)", lines=8)
            v_audio = gr.Audio(label="🔊 Bot Voice Reply")

    # Bind events
    upload_btn.click(fn=upload_file, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=query, outputs=answer)
    mic_input.change(fn=voice_query, inputs=mic_input, outputs=[rec_text, v_answer, v_audio])

# Launch only when run as a script, so importing this module does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch()