Files changed (1) hide show
  1. app.py +118 -120
app.py CHANGED
@@ -1,89 +1,101 @@
1
  import os
2
  import gradio as gr
3
- import fitz # PyMuPDF for PDFs
4
  import docx
5
  import faiss
6
  import numpy as np
7
  import torch
 
8
  from sentence_transformers import SentenceTransformer
9
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
10
  from gtts import gTTS
11
  from huggingface_hub import login
12
 
13
# =============================
# 1) Auth & Config
# =============================
# The Llama model below is gated on the Hub, so a user token is mandatory.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("⚠️ Please set your HF_TOKEN as an environment variable.")

EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # sentence embeddings for retrieval
LLM_MODEL_ID = "meta-llama/Llama-3.2-3b-instruct"          # gated generator model
ASR_MODEL_ID = "openai/whisper-small"                      # speech-to-text

# =============================
# 2) Load Models
# =============================
# All models are loaded once at import time (module-level side effects).
embedding_model = SentenceTransformer(EMBED_MODEL_ID)

# Authenticate before pulling the gated LLM weights.
login(HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, token=HF_TOKEN)
llm = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    device_map="auto",          # place layers automatically across available devices
    torch_dtype=torch.float16,  # halve memory footprint for inference
    token=HF_TOKEN
)

# Whisper (speech-to-text)
stt_model = pipeline("automatic-speech-recognition", model=ASR_MODEL_ID, token=HF_TOKEN)
40
-
41
  # =============================
42
- # 3) File Text Extraction
43
  # =============================
44
def extract_text(file_path: str) -> str:
    """Extract raw text from a PDF, DOCX, or plain-text file.

    Args:
        file_path: Path to the uploaded file; falsy values yield "".

    Returns:
        The extracted text. Unknown extensions are read as raw bytes and
        decoded as UTF-8, ignoring undecodable bytes.
    """
    if not file_path:
        return ""
    _, ext = os.path.splitext(file_path.lower())
    text = ""
    if ext == ".pdf":
        # Fix: use the Document as a context manager so the PDF handle is
        # closed even if extraction raises (original leaked it).
        with fitz.open(file_path) as doc:
            for page in doc:
                text += page.get_text("text")
    elif ext == ".docx":
        doc = docx.Document(file_path)
        for para in doc.paragraphs:
            text += para.text + "\n"
    else:
        with open(file_path, "rb") as f:
            text = f.read().decode("utf-8", errors="ignore")
    return text
 
 
 
 
 
 
61
 
62
  # =============================
63
  # 4) Build FAISS Index
64
  # =============================
65
def build_faiss(text: str, chunk_size=500, overlap=50):
    """Split *text* into overlapping character chunks and index them.

    Returns:
        (faiss index, list of chunks), or (None, None) when the input is
        blank or yields no non-empty chunks.
    """
    if not text.strip():
        return None, None

    # Stride is clamped to 1 so a huge overlap can never stall the walk.
    stride = max(1, chunk_size - overlap)
    pieces = [
        text[start:start + chunk_size]
        for start in range(0, len(text), stride)
        if text[start:start + chunk_size].strip()
    ]
    if not pieces:
        return None, None

    vectors = embedding_model.encode(pieces, convert_to_numpy=True, normalize_embeddings=True)
    # Inner product over L2-normalized vectors is cosine similarity.
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index, pieces
84
 
85
  # =============================
86
- # 5) Globals (indexed docs)
87
  # =============================
88
  doc_index = None
89
  doc_chunks = None
@@ -91,98 +103,84 @@ doc_chunks = None
91
  # =============================
92
  # 6) Handlers
93
  # =============================
94
def upload_file(file_path: str):
    """Index the uploaded document and return a status message for the UI."""
    global doc_index, doc_chunks

    if not file_path:
        return "⚠️ Please upload a file first."

    index, chunks = build_faiss(extract_text(file_path))
    if index is None:
        return "⚠️ Could not index: file appears empty."

    doc_index = index
    doc_chunks = chunks
    return f"✅ Document indexed! {len(chunks)} chunks ready."
104
 
105
def answer_query(query: str):
    """Answer *query* with RAG: retrieve top chunks, then prompt the LLM.

    Returns a user-facing warning string when the query is blank or no
    document has been indexed yet.
    """
    global doc_index, doc_chunks
    if not query or not query.strip():
        return "⚠️ Please enter a question."
    if doc_index is None or not doc_chunks:
        return "⚠️ Please upload and index a document first."

    # ---- Retrieve context ----
    query_vec = embedding_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    top_k = min(5, len(doc_chunks))
    _scores, ids = doc_index.search(query_vec, k=top_k)
    # Drop any padding/out-of-range ids before mapping back to text.
    hits = [doc_chunks[i] for i in ids[0] if 0 <= i < len(doc_chunks)]
    context = "\n".join(hits)

    # ---- Final Answer ----
    final_prompt = f"""
[INST] You are a helpful tutor. Based only on the context below, answer the question.
If not in context, say "I could not find this in the text."
Context:
{context}
Question: {query}
Answer: [/INST]
"""
    model_inputs = tokenizer(final_prompt, return_tensors="pt", truncation=True).to(llm.device)
    generated = llm.generate(**model_inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, do_sample=True)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Keep only the text after the "Answer:" marker when it is present.
    if "Answer:" in decoded:
        return decoded.split("Answer:")[-1].strip()
    return decoded
134
-
135
def synthesize_with_gtts(text: str, out_path="out.mp3"):
    """Render *text* as English speech via gTTS and return the mp3 path."""
    speech = gTTS(text=text, lang="en")
    speech.save(out_path)
    return out_path
139
-
140
def voice_query(audio_path: str):
    """Full voice round-trip: transcribe, answer, then speak the answer.

    Returns:
        (recognized text, answer text, mp3 path). On failure the first
        element carries the warning and the audio slot is None.
    """
    if not audio_path:
        return "⚠️ Please record your question.", "", None

    # 1) Speech -> Text
    transcription = stt_model(audio_path).get("text", "").strip()
    if not transcription:
        return "⚠️ Could not transcribe audio.", "", None

    # 2) Answer Query
    reply = answer_query(transcription)

    # 3) Text -> Speech
    return transcription, reply, synthesize_with_gtts(reply, "answer.mp3")
 
 
 
 
 
 
 
 
 
 
157
 
158
  # =============================
159
- # 7) Gradio UI
160
  # =============================
161
# =============================
# 7) Gradio UI
# =============================
# Two-pane layout: the left column handles upload/indexing, the right
# column offers both a typed chat and a microphone-driven voice chat.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="cyan")) as demo:
    gr.Markdown("# 📚 RAG Chatbot + 🎤 Voice (Whisper + gTTS)")
    gr.Markdown("Upload a PDF/DOCX/TXT and ask by typing **or** speaking.")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="📂 Upload Document", type="filepath")
            upload_btn = gr.Button("⚡ Index Document", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("### ✍️ Text Chat")
            query = gr.Textbox(label="❓ Ask a Question", placeholder="e.g., What are the key points?")
            ask_btn = gr.Button("🚀 Get Answer", variant="primary")
            answer = gr.Textbox(label="💡 Answer", lines=8)

            gr.Markdown("### 🎤 Voice Chat")
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Speak your question")
            rec_text = gr.Textbox(label="📝 Recognized Speech", interactive=False)
            v_answer = gr.Textbox(label="💡 Answer (voice)", lines=8)
            v_audio = gr.Audio(label="🔊 Bot Voice Reply")

    # Bind events
    upload_btn.click(fn=upload_file, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=query, outputs=answer)
    # The voice pipeline fires automatically whenever a new recording lands.
    mic_input.change(fn=voice_query, inputs=mic_input, outputs=[rec_text, v_answer, v_audio])

demo.launch()
 
 
1
  import os
2
  import gradio as gr
3
+ import fitz
4
  import docx
5
  import faiss
6
  import numpy as np
7
  import torch
8
+
9
  from sentence_transformers import SentenceTransformer
10
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
11
  from gtts import gTTS
12
  from huggingface_hub import login
13
 
14
  # =============================
15
+ # 1) Config
16
  # =============================
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
+ if not HF_TOKEN:
19
+ raise ValueError("Please set HF_TOKEN in Space secrets")
20
+
21
+ login(HF_TOKEN)
22
 
23
  EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
24
+ LLM_MODEL_ID = "google/flan-t5-base"
25
  ASR_MODEL_ID = "openai/whisper-small"
26
 
27
  # =============================
28
+ # 2) Load Models (cached)
29
  # =============================
30
  embedding_model = SentenceTransformer(EMBED_MODEL_ID)
31
 
32
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
33
+ llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_ID)
34
+
35
+ stt_model = pipeline(
36
+ "automatic-speech-recognition",
37
+ model=ASR_MODEL_ID,
38
  token=HF_TOKEN
39
  )
40
 
 
 
 
41
  # =============================
42
+ # 3) Text Extraction
43
  # =============================
44
def extract_text(file_path: str) -> str:
    """Extract text from a PDF, DOCX, or plain-text file (best effort).

    Args:
        file_path: Path to the uploaded file; falsy values yield "".

    Returns:
        The stripped extracted text, or "" when the path is falsy or any
        parse/read error occurs (reported upstream as "no readable text").
    """
    if not file_path:
        return ""

    text = ""
    ext = os.path.splitext(file_path)[1].lower()

    try:
        if ext == ".pdf":
            # Fix: context manager closes the PDF handle even on error
            # (original left the document open).
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
        elif ext == ".docx":
            doc = docx.Document(file_path)
            for p in doc.paragraphs:
                text += p.text + "\n"
        else:
            with open(file_path, "r", errors="ignore") as f:
                text = f.read()
    except Exception:
        # Deliberate best-effort swallow: caller shows a generic failure.
        return ""

    return text.strip()
67
 
68
  # =============================
69
  # 4) Build FAISS Index
70
  # =============================
71
def build_faiss(text, chunk_size=500, overlap=50):
    """Split *text* into overlapping chunks and build a cosine FAISS index.

    Args:
        text: Document text; falsy input yields (None, None).
        chunk_size: Characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        (index, chunks), or (None, None) when nothing indexable is found.
    """
    if not text:
        return None, None

    # Fix: clamp the stride to >= 1 so overlap >= chunk_size cannot make
    # range() raise ValueError (step 0) — the previous version dropped
    # this guard.
    step = max(1, chunk_size - overlap)

    chunks = []
    for i in range(0, len(text), step):
        chunk = text[i:i + chunk_size].strip()
        if chunk:
            chunks.append(chunk)

    if not chunks:
        return None, None

    embeds = embedding_model.encode(
        chunks,
        convert_to_numpy=True,
        normalize_embeddings=True
    )

    # Inner product over L2-normalized embeddings == cosine similarity.
    index = faiss.IndexFlatIP(embeds.shape[1])
    index.add(embeds)

    return index, chunks
96
 
97
# =============================
# 5) Globals
# =============================
# Module-level state shared between the upload and query handlers:
# doc_index holds the FAISS index, doc_chunks the parallel list of text chunks.
doc_index = None
doc_chunks = None
 
103
  # =============================
104
  # 6) Handlers
105
  # =============================
106
+ def upload_file(file_path):
107
  global doc_index, doc_chunks
 
 
108
  text = extract_text(file_path)
109
+
110
+ if not text:
111
+ return "❌ No readable text found."
112
+
113
  idx, chunks = build_faiss(text)
114
+
115
  if idx is None:
116
+ return " Indexing failed."
117
+
118
  doc_index, doc_chunks = idx, chunks
119
+ return f"✅ Indexed {len(chunks)} chunks."
120
 
121
+ def answer_query(query):
122
+ if not query.strip():
123
+ return "⚠️ Enter a question."
124
+
125
+ if doc_index is None:
126
+ return "⚠️ Upload a document first."
127
+
128
+ q_vec = embedding_model.encode(
129
+ [query],
130
+ convert_to_numpy=True,
131
+ normalize_embeddings=True
132
+ )
133
+
134
+ _, I = doc_index.search(q_vec, k=5)
135
+ context = "\n".join(doc_chunks[i] for i in I[0])
136
+
137
+ prompt = f"""
138
+ Answer using only the context below.
139
+ If not found, say "Not in document".
140
+
141
+ Context:
142
+ {context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ Question:
145
+ {query}
146
+ """
 
 
147
 
148
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
149
+ outputs = llm.generate(**inputs, max_new_tokens=200)
150
 
151
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
152
 
153
def voice_query(audio_path):
    """Voice pipeline: transcribe the recording, answer it, speak the answer.

    Args:
        audio_path: Path to the recorded audio; falsy yields empty outputs.

    Returns:
        (recognized text, answer text, mp3 path or None).
    """
    if not audio_path:
        return "", "", None

    # Fix: normalize the transcription and bail out early when Whisper
    # returns nothing, instead of pushing empty text through the LLM path.
    speech = stt_model(audio_path).get("text", "").strip()
    if not speech:
        return "", "⚠️ Could not transcribe audio.", None

    answer = answer_query(speech)

    tts = gTTS(answer)
    tts.save("reply.mp3")

    return speech, answer, "reply.mp3"
164
 
165
  # =============================
166
+ # 7) UI
167
  # =============================
168
# =============================
# 7) UI
# =============================
# Minimal single-column layout: indexing controls, then text chat,
# then the voice round-trip widgets.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 RAG Chatbot with Voice")

    # Document upload + indexing
    file = gr.File(type="filepath")
    status = gr.Textbox()
    gr.Button("Index").click(upload_file, file, status)

    # Typed Q&A
    query = gr.Textbox(label="Question")
    answer = gr.Textbox()
    gr.Button("Ask").click(answer_query, query, answer)

    # Voice Q&A: a new recording triggers the full pipeline automatically.
    audio = gr.Audio(type="filepath")
    rec = gr.Textbox()
    v_ans = gr.Textbox()
    v_audio = gr.Audio()
    audio.change(voice_query, audio, [rec, v_ans, v_audio])

demo.launch()
186
+