rahul7star committed on
Commit
ad2de89
·
verified ·
1 Parent(s): a7ced7c

Update app_qwen_tts.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts.py +37 -60
app_qwen_tts.py CHANGED
@@ -1,12 +1,9 @@
1
  import os
2
- import io
3
- import base64
4
- import requests
5
  import torch
6
- import numpy as np
7
- import soundfile as sf
8
  import gradio as gr
9
-
 
 
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from sentence_transformers import SentenceTransformer
12
 
@@ -17,7 +14,7 @@ MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
17
  DOC_FILE = "general.md"
18
  MAX_NEW_TOKENS = 200
19
  TOP_K = 3
20
- TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts" # FastAPI TTS endpoint
21
 
22
  # =========================================================
23
  # Paths
@@ -31,9 +28,7 @@ if not os.path.exists(DOC_PATH):
31
  # =========================================================
32
  # Load Qwen Model
33
  # =========================================================
34
- print("🔄 Loading Qwen model...")
35
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
36
-
37
  model = AutoModelForCausalLM.from_pretrained(
38
  MODEL_ID,
39
  device_map="auto",
@@ -41,7 +36,6 @@ model = AutoModelForCausalLM.from_pretrained(
41
  trust_remote_code=True
42
  )
43
  model.eval()
44
- print("✅ Qwen model loaded.")
45
 
46
  # =========================================================
47
  # Embedding Model
@@ -49,7 +43,7 @@ print("✅ Qwen model loaded.")
49
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
50
 
51
  # =========================================================
52
- # Document Chunking
53
  # =========================================================
54
  def chunk_text(text, chunk_size=300, overlap=50):
55
  words = text.split()
@@ -68,7 +62,7 @@ DOC_CHUNKS = chunk_text(DOC_TEXT)
68
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
69
 
70
  # =========================================================
71
- # Retrieval
72
  # =========================================================
73
  def retrieve_context(question, k=TOP_K):
74
  q_emb = embedder.encode([question], normalize_embeddings=True)
@@ -89,20 +83,23 @@ def extract_final_answer(text: str) -> str:
89
  return lines[-1] if lines else text
90
 
91
  # =========================================================
92
- # Qwen Inference
93
  # =========================================================
94
  def answer_question(question):
95
  context = retrieve_context(question)
96
 
97
  messages = [
98
- {"role": "system", "content": (
99
- "You are a strict document-based Q&A assistant.\n"
100
- "Answer ONLY the question.\n"
101
- "Do NOT repeat the context or the question.\n"
102
- "Respond in 1–2 sentences.\n"
103
- "If the answer is not present, say:\n"
104
- "'I could not find this information in the document.'"
105
- )},
 
 
 
106
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
107
  ]
108
 
@@ -110,74 +107,55 @@ def answer_question(question):
110
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
111
 
112
  with torch.no_grad():
113
- output = model.generate(
114
- **inputs,
115
- max_new_tokens=MAX_NEW_TOKENS,
116
- temperature=0.3,
117
- do_sample=True
118
- )
119
 
120
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
121
  return extract_final_answer(decoded)
122
 
123
  # =========================================================
124
- # TTS via API (returns path to WAV for Gradio)
125
  # =========================================================
126
- def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
127
- payload = {
128
- "text": text,
129
- "language_id": language_id,
130
- "mode": mode,
131
- "exaggeration": exaggeration,
132
- "temperature": temperature,
133
- "cfg_weight": cfg_weight
134
- }
135
  try:
 
136
  resp = requests.post(TTS_API_URL, json=payload, timeout=60)
137
  resp.raise_for_status()
138
  data = resp.json()
139
  audio_b64 = data.get("audio", "")
140
  if not audio_b64:
141
  return None
142
-
143
- # Convert base64 to WAV for Gradio
144
  audio_bytes = base64.b64decode(audio_b64)
145
- audio_buffer = io.BytesIO(audio_bytes)
146
- wav, sr = sf.read(audio_buffer, dtype="float32")
147
- temp_path = "/tmp/response.wav"
148
- sf.write(temp_path, wav, sr)
149
- return temp_path
150
-
151
  except Exception as e:
152
  print(f"TTS API error: {e}")
153
  return None
154
 
155
  # =========================================================
156
- # Chat function for Gradio
157
  # =========================================================
158
  def chat(user_message, history):
159
  if not user_message.strip():
160
  return "", history
161
 
162
  try:
163
- # Generate text answer
164
  answer_text = answer_question(user_message)
165
 
166
  # Generate audio
167
  audio_path = tts_via_api(answer_text)
168
 
169
- # Append user message
170
- history.append({"role": "user", "content": user_message})
171
 
172
- # Append assistant message with text + audio
173
- if audio_path:
174
- history.append({"role": "assistant", "content": [f"**Bot:** {answer_text}", audio_path]})
175
- else:
176
- history.append({"role": "assistant", "content": f"**Bot:** {answer_text}"})
177
 
178
  except Exception as e:
179
  print(e)
180
- history.append({"role": "assistant", "content": "**⚠️ Error generating response.**"})
181
 
182
  return "", history
183
 
@@ -185,14 +163,13 @@ def reset_chat():
185
  return []
186
 
187
  # =========================================================
188
- # Build Gradio UI
189
  # =========================================================
190
  def build_ui():
191
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
192
- gr.Markdown("# 📄 Qwen Document Assistant + TTS")
193
- gr.Markdown("Ask questions and hear the answers as audio.")
194
 
195
- chatbot = gr.Chatbot(height=450, type="messages")
196
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
197
  send = gr.Button("Send")
198
  clear = gr.Button("🧹 Clear")
@@ -201,7 +178,7 @@ def build_ui():
201
  msg.submit(chat, [msg, chatbot], [msg, chatbot])
202
  clear.click(reset_chat, outputs=chatbot)
203
 
204
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
205
 
206
  # =========================================================
207
  # Entrypoint
 
1
  import os
 
 
 
2
  import torch
 
 
3
  import gradio as gr
4
+ import numpy as np
5
+ import base64
6
+ import requests
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from sentence_transformers import SentenceTransformer
9
 
 
14
  DOC_FILE = "general.md"
15
  MAX_NEW_TOKENS = 200
16
  TOP_K = 3
17
+ TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
18
 
19
  # =========================================================
20
  # Paths
 
28
  # =========================================================
29
  # Load Qwen Model
30
  # =========================================================
 
31
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
  MODEL_ID,
34
  device_map="auto",
 
36
  trust_remote_code=True
37
  )
38
  model.eval()
 
39
 
40
  # =========================================================
41
  # Embedding Model
 
43
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
44
 
45
  # =========================================================
46
+ # Load Document
47
  # =========================================================
48
  def chunk_text(text, chunk_size=300, overlap=50):
49
  words = text.split()
 
62
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
63
 
64
  # =========================================================
65
+ # Retrieve context
66
  # =========================================================
67
  def retrieve_context(question, k=TOP_K):
68
  q_emb = embedder.encode([question], normalize_embeddings=True)
 
83
  return lines[-1] if lines else text
84
 
85
  # =========================================================
86
+ # Generate text answer
87
  # =========================================================
88
  def answer_question(question):
89
  context = retrieve_context(question)
90
 
91
  messages = [
92
+ {
93
+ "role": "system",
94
+ "content": (
95
+ "You are a strict document-based Q&A assistant.\n"
96
+ "Answer ONLY the question.\n"
97
+ "Do NOT repeat the context or the question.\n"
98
+ "Respond in 1–2 sentences.\n"
99
+ "If the answer is not present, say:\n"
100
+ "'I could not find this information in the document.'"
101
+ )
102
+ },
103
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
104
  ]
105
 
 
107
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
108
 
109
  with torch.no_grad():
110
+ output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.3, do_sample=True)
 
 
 
 
 
111
 
112
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
113
  return extract_final_answer(decoded)
114
 
115
  # =========================================================
116
+ # Call TTS API and get audio path
117
  # =========================================================
118
def tts_via_api(text: str):
    """Send *text* to the remote TTS endpoint and return a path to a WAV file.

    Returns None when the request fails or the response carries no audio,
    so callers can degrade gracefully to text-only replies.
    The service is expected to reply with JSON carrying a base64-encoded
    WAV under the "audio" key — NOTE(review): confirm against the API.
    """
    import tempfile  # stdlib; local import keeps module-level imports unchanged

    try:
        payload = {"text": text}
        resp = requests.post(TTS_API_URL, json=payload, timeout=60)
        resp.raise_for_status()
        data = resp.json()
        audio_b64 = data.get("audio", "")
        if not audio_b64:
            return None
        audio_bytes = base64.b64decode(audio_b64)
        # Write to a unique temp file instead of a fixed "/tmp/output.wav":
        # the fixed path is not portable and concurrent chat requests would
        # overwrite each other's audio before Gradio serves it.
        fd, audio_path = tempfile.mkstemp(suffix=".wav")
        with os.fdopen(fd, "wb") as f:
            f.write(audio_bytes)
        return audio_path
    except Exception as e:
        # Best-effort: audio is optional, so log and return None.
        print(f"TTS API error: {e}")
        return None
135
 
136
  # =========================================================
137
+ # Gradio Chat function
138
  # =========================================================
139
def chat(user_message, history):
    """Handle one chat turn: answer the question, synthesize audio, update history.

    ``history`` uses Gradio's "tuples" chat format, where each entry is a
    ``(user_message, bot_message)`` pair.  Returns ``("", history)`` so the
    input textbox is cleared after every turn.
    """
    if not user_message.strip():
        return "", history

    try:
        # Generate the text answer, then best-effort audio for it.
        answer_text = answer_question(user_message)
        audio_path = tts_via_api(answer_text)  # may be None on TTS failure

        # BUG FIX: in tuples format the user message belongs in the first
        # slot and the bot reply in the second; previously the reply was
        # appended first with the roles swapped, and a raw file path was
        # placed where a message string is expected.
        history.append((user_message, f"**Bot:** {answer_text}"))
        if audio_path:
            # A one-element (filepath,) tuple in the bot slot renders as
            # a playable audio component in gr.Chatbot.
            history.append((None, (audio_path,)))
    except Exception as e:
        print(e)
        history.append((user_message, "⚠️ Error generating response"))

    return "", history
161
 
 
163
  return []
164
 
165
  # =========================================================
166
+ # Build UI
167
  # =========================================================
168
  def build_ui():
169
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
+ gr.Markdown("# 📄 Qwen Document Assistant + TTS\nAsk questions and listen to answers.")
 
171
 
172
+ chatbot = gr.Chatbot(height=450, type="tuples")
173
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
174
  send = gr.Button("Send")
175
  clear = gr.Button("🧹 Clear")
 
178
  msg.submit(chat, [msg, chatbot], [msg, chatbot])
179
  clear.click(reset_chat, outputs=chatbot)
180
 
181
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
182
 
183
  # =========================================================
184
  # Entrypoint