rahul7star committed on
Commit
9777df5
·
verified ·
1 Parent(s): 3067843

Update app_qwen_tts.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts.py +77 -68
app_qwen_tts.py CHANGED
@@ -1,36 +1,37 @@
1
  import os
2
- import requests
3
  import torch
4
- import numpy as np
5
  import gradio as gr
 
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
 
9
  # =========================================================
10
- # Config
11
  # =========================================================
12
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
13
  DOC_FILE = "general.md"
 
14
  MAX_NEW_TOKENS = 200
15
  TOP_K = 3
16
-
17
  TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts" # FastAPI TTS endpoint
18
 
19
  # =========================================================
20
- # Load document
21
  # =========================================================
22
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
 
24
  if not os.path.exists(DOC_PATH):
25
  raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
26
 
27
- with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
28
- DOC_TEXT = f.read()
29
-
30
  # =========================================================
31
- # Qwen model
32
  # =========================================================
33
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
34
  model = AutoModelForCausalLM.from_pretrained(
35
  MODEL_ID,
36
  device_map="auto",
@@ -40,29 +41,41 @@ model = AutoModelForCausalLM.from_pretrained(
40
  model.eval()
41
 
42
  # =========================================================
43
- # Embeddings for retrieval
44
  # =========================================================
45
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
46
 
 
 
 
47
  def chunk_text(text, chunk_size=300, overlap=50):
48
  words = text.split()
49
  chunks = []
50
  i = 0
51
  while i < len(words):
52
- chunk = words[i:i+chunk_size]
53
  chunks.append(" ".join(chunk))
54
  i += chunk_size - overlap
55
  return chunks
56
 
 
 
 
57
  DOC_CHUNKS = chunk_text(DOC_TEXT)
58
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
59
 
 
 
 
60
  def retrieve_context(question, k=TOP_K):
61
  q_emb = embedder.encode([question], normalize_embeddings=True)
62
  scores = np.dot(DOC_EMBEDS, q_emb[0])
63
  top_ids = scores.argsort()[-k:][::-1]
64
  return "\n\n".join([DOC_CHUNKS[i] for i in top_ids])
65
 
 
 
 
66
  def extract_final_answer(text: str) -> str:
67
  text = text.strip()
68
  markers = ["assistant:", "assistant", "answer:", "final answer:"]
@@ -73,30 +86,46 @@ def extract_final_answer(text: str) -> str:
73
  return lines[-1] if lines else text
74
 
75
  # =========================================================
76
- # Qwen inference
77
  # =========================================================
78
- def answer_question(question: str) -> str:
79
  context = retrieve_context(question)
 
80
  messages = [
81
- {"role": "system", "content": (
82
- "You are a strict document-based Q&A assistant.\n"
83
- "Answer ONLY the question.\n"
84
- "Do NOT repeat the context or the question.\n"
85
- "Respond in 1–2 sentences.\n"
86
- "If the answer is not present, say:\n"
87
- "'I could not find this information in the document.'"
88
- )},
89
- {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
 
 
 
 
 
 
90
  ]
 
91
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
92
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
93
  with torch.no_grad():
94
- output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.3, do_sample=True)
 
 
 
 
 
 
95
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
96
  return extract_final_answer(decoded)
97
 
98
  # =========================================================
99
- # Call TTS API
 
100
  # =========================================================
101
  def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
102
  payload = {
@@ -120,37 +149,33 @@ def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=
120
  return None
121
 
122
  # =========================================================
123
- # Gradio Chat
124
  # =========================================================
125
- def chat(user_message, history, language_id, mode, exaggeration, temperature, cfg_weight):
126
  if not user_message.strip():
127
  return "", history
 
128
  try:
129
- # 1️⃣ Get Qwen answer
130
  answer_text = answer_question(user_message)
131
 
132
- # 2️⃣ Get TTS from API
133
- audio_src = tts_via_api(answer_text, language_id, mode, exaggeration, temperature, cfg_weight)
 
 
 
 
 
134
 
135
- # 3️⃣ Format bot message nicely with spacing
136
- if audio_src:
137
- # Use a small HTML wrapper for spacing
138
- bot_content = [
139
- f"<div style='margin-bottom:8px;'>{answer_text}</div>", # text with margin
140
- audio_src # playable audio below
141
- ]
142
- else:
143
- bot_content = [answer_text]
144
 
145
  except Exception as e:
146
  print(e)
147
- bot_content = ["⚠️ Error generating answer or audio."]
148
 
149
- # 4️⃣ Append as tuple: (user_message, bot_content)
150
- history.append((user_message, bot_content))
151
  return "", history
152
 
153
-
154
  def reset_chat():
155
  return []
156
 
@@ -159,38 +184,22 @@ def reset_chat():
159
  # =========================================================
160
  def build_ui():
161
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
162
- gr.Markdown("## 🧠 Qwen Document Chat + TTS (API-based)")
163
-
164
  chatbot = gr.Chatbot(height=450, type="tuples")
165
-
166
- with gr.Row():
167
- msg = gr.Textbox(placeholder="Ask a question...", lines=2, scale=8)
168
- send = gr.Button("🚀 Send", scale=2)
169
-
170
- with gr.Row():
171
- language_id = gr.Dropdown(["en","fr","hi","he"], value="en", label="TTS Language")
172
- mode = gr.Radio(["Speak 🗣️", "Sing 🎵"], value="Speak 🗣️", label="TTS Mode")
173
- exaggeration = gr.Slider(0.25, 2.0, step=0.05, value=0.5, label="Exaggeration")
174
- temperature = gr.Slider(0.1, 2.0, step=0.05, value=0.8, label="Temperature")
175
- cfg_weight = gr.Slider(0.2, 1.0, step=0.05, value=0.5, label="CFG / Pace")
176
-
177
  clear = gr.Button("🧹 Clear")
178
 
179
- send.click(
180
- chat,
181
- [msg, chatbot, language_id, mode, exaggeration, temperature, cfg_weight],
182
- [msg, chatbot]
183
- )
184
- msg.submit(
185
- chat,
186
- [msg, chatbot, language_id, mode, exaggeration, temperature, cfg_weight],
187
- [msg, chatbot]
188
- )
189
  clear.click(reset_chat, outputs=chatbot)
190
 
191
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
192
 
 
 
 
193
  if __name__ == "__main__":
194
  print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}")
195
- print(f"✅ Qwen Model: {MODEL_ID}")
196
  build_ui()
 
1
  import os
 
2
  import torch
 
3
  import gradio as gr
4
+ import numpy as np
5
+ import soundfile as sf
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
 
9
  # =========================================================
10
+ # Configuration
11
  # =========================================================
12
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
13
  DOC_FILE = "general.md"
14
+
15
  MAX_NEW_TOKENS = 200
16
  TOP_K = 3
 
17
  TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts" # FastAPI TTS endpoint
18
 
19
  # =========================================================
20
+ # Paths
21
  # =========================================================
22
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
24
+
25
  if not os.path.exists(DOC_PATH):
26
  raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
27
 
 
 
 
28
  # =========================================================
29
+ # Load Qwen Model
30
  # =========================================================
31
+ tokenizer = AutoTokenizer.from_pretrained(
32
+ MODEL_ID, trust_remote_code=True
33
+ )
34
+
35
  model = AutoModelForCausalLM.from_pretrained(
36
  MODEL_ID,
37
  device_map="auto",
 
41
  model.eval()
42
 
43
  # =========================================================
44
+ # Embedding Model
45
  # =========================================================
46
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
47
 
48
+ # =========================================================
49
+ # Document Chunking
50
+ # =========================================================
51
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping word-window chunks.

    Args:
        text: Raw document text; split on whitespace.
        chunk_size: Number of words per chunk.
        overlap: Words shared between consecutive chunks.

    Returns:
        List of chunk strings (empty list for empty text).

    Raises:
        ValueError: If overlap >= chunk_size (the window would never
            advance, which previously caused an infinite loop).
    """
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk = words[i:i + chunk_size]
        chunks.append(" ".join(chunk))
        i += step
    return chunks
60
 
61
+ with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
62
+ DOC_TEXT = f.read()
63
+
64
  DOC_CHUNKS = chunk_text(DOC_TEXT)
65
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
66
 
67
+ # =========================================================
68
+ # Retrieval
69
+ # =========================================================
70
def retrieve_context(question, k=TOP_K):
    """Return the *k* most relevant document chunks for *question*.

    Cosine similarity reduces to a dot product because both the chunk
    embeddings and the query embedding are L2-normalized at encode time.
    The selected chunks are joined with blank lines, best match first.
    """
    query_vec = embedder.encode([question], normalize_embeddings=True)[0]
    similarity = np.dot(DOC_EMBEDS, query_vec)
    best_ids = similarity.argsort()[-k:][::-1]
    return "\n\n".join(DOC_CHUNKS[idx] for idx in best_ids)
75
 
76
+ # =========================================================
77
+ # Extract final answer
78
+ # =========================================================
79
  def extract_final_answer(text: str) -> str:
80
  text = text.strip()
81
  markers = ["assistant:", "assistant", "answer:", "final answer:"]
 
86
  return lines[-1] if lines else text
87
 
88
  # =========================================================
89
+ # Qwen Inference
90
  # =========================================================
91
def answer_question(question):
    """Answer *question* using retrieved document context and the Qwen model.

    Builds a chat prompt that constrains the model to the retrieved
    context, generates with light sampling, and strips the echoed prompt
    via extract_final_answer().
    """
    context = retrieve_context(question)

    system_prompt = (
        "You are a strict document-based Q&A assistant.\n"
        "Answer ONLY the question.\n"
        "Do NOT repeat the context or the question.\n"
        "Respond in 1–2 sentences.\n"
        "If the answer is not present, say:\n"
        "'I could not find this information in the document.'"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            do_sample=True,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return extract_final_answer(decoded)
125
 
126
  # =========================================================
127
+ # CPU-friendly TTS
128
+
129
  # =========================================================
130
  def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
131
  payload = {
 
149
  return None
150
 
151
  # =========================================================
152
+ # Gradio Chatbot function
153
  # =========================================================
154
def chat(user_message, history):
    """Handle one chat turn: generate a text answer, then try to add audio.

    Args:
        user_message: The user's question (blank input is a no-op).
        history: Gradio "tuples"-style chat history; mutated by appending
            one (user_message, bot_content) pair.

    Returns:
        ("", history) — clears the textbox and updates the chatbot.

    Fixes: tts_via_api() returns None on failure, so the original
    `sr, wav = tts_via_api(...)` raised TypeError and the blanket except
    discarded an already-generated answer. Now a TTS failure degrades to
    a text-only reply, and only answer generation itself reports an error.
    """
    if not user_message.strip():
        return "", history

    # 1) Generate the text answer first so a later TTS failure cannot lose it.
    try:
        answer_text = answer_question(user_message)
    except Exception as e:
        print(e)
        history.append((user_message, ["⚠️ Error generating answer or audio."]))
        return "", history

    # 2) Generate audio; tts_via_api is expected to return (sample_rate, wav)
    #    or None on failure — TODO confirm against the TTS API contract.
    try:
        result = tts_via_api(answer_text)
        if result is None:
            raise RuntimeError("TTS API returned no audio")
        sr, wav = result

        # Save a temp wav for the Gradio audio player; `sf` is the
        # module-level soundfile import.
        audio_path = "/tmp/output.wav"
        sf.write(audio_path, wav, sr)

        # 3) Append as tuple (text + audio).
        history.append((user_message, [answer_text, audio_path]))
    except Exception as e:
        # TTS failed — still show the text answer instead of an error.
        print(e)
        history.append((user_message, [answer_text]))

    return "", history
178
 
 
179
def reset_chat():
    """Return an empty history, clearing the Gradio chatbot display."""
    return list()
181
 
 
184
  # =========================================================
185
def build_ui():
    """Build the Gradio chat interface and launch the server (blocking)."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 📄 Qwen Document Assistant + TTS")

        # "tuples" mode stores (user, bot) pairs, matching what chat() appends.
        chatbot = gr.Chatbot(height=450, type="tuples")
        msg = gr.Textbox(placeholder="Ask a question...", lines=2)
        send = gr.Button("Send")
        clear = gr.Button("🧹 Clear")

        # Both the Send button and pressing Enter in the textbox run a chat
        # turn; chat() returns ("", history) so the textbox is cleared.
        send.click(chat, [msg, chatbot], [msg, chatbot])
        msg.submit(chat, [msg, chatbot], [msg, chatbot])
        clear.click(reset_chat, outputs=chatbot)

    # NOTE(review): the diff rendering strips indentation — launch is assumed
    # to sit after the `with` block (conventional); behavior is the same
    # either way. Binds all interfaces on port 7860, no public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
198
 
199
+ # =========================================================
200
+ # Entrypoint
201
+ # =========================================================
202
# Entry point: report startup stats, then build and launch the UI (blocking).
if __name__ == "__main__":
    print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}")
    print(f"✅ Model: {MODEL_ID}")
    build_ui()
  build_ui()