rahul7star commited on
Commit
365d690
·
verified ·
1 Parent(s): b7ef6fe

Update app_qwen_tts.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts.py +74 -54
app_qwen_tts.py CHANGED
@@ -1,31 +1,33 @@
1
  import os
 
2
  import torch
3
  import gradio as gr
4
  import numpy as np
5
- import requests
6
- import base64
7
- import io
8
- import soundfile as sf
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from sentence_transformers import SentenceTransformer
11
 
12
  # =========================================================
13
  # Configuration
 
14
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
15
  DOC_FILE = "general.md"
16
  MAX_NEW_TOKENS = 200
17
  TOP_K = 3
 
 
18
  TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
19
 
20
  # =========================================================
21
  # Paths
 
22
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
24
  if not os.path.exists(DOC_PATH):
25
- raise RuntimeError(f"{DOC_FILE} not found")
26
 
27
  # =========================================================
28
- # Load Qwen Model
 
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
30
  model = AutoModelForCausalLM.from_pretrained(
31
  MODEL_ID,
@@ -36,17 +38,19 @@ model = AutoModelForCausalLM.from_pretrained(
36
  model.eval()
37
 
38
  # =========================================================
39
- # Embeddings
 
40
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
41
 
42
  # =========================================================
43
- # Document chunking
 
44
  def chunk_text(text, chunk_size=300, overlap=50):
45
  words = text.split()
46
  chunks = []
47
  i = 0
48
  while i < len(words):
49
- chunks.append(" ".join(words[i:i+chunk_size]))
50
  i += chunk_size - overlap
51
  return chunks
52
 
@@ -57,7 +61,8 @@ DOC_CHUNKS = chunk_text(DOC_TEXT)
57
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
58
 
59
  # =========================================================
60
- # Retrieve context
 
61
  def retrieve_context(question, k=TOP_K):
62
  q_emb = embedder.encode([question], normalize_embeddings=True)
63
  scores = np.dot(DOC_EMBEDS, q_emb[0])
@@ -65,10 +70,11 @@ def retrieve_context(question, k=TOP_K):
65
  return "\n\n".join([DOC_CHUNKS[i] for i in top_ids])
66
 
67
  # =========================================================
68
- # Extract answer
 
69
  def extract_final_answer(text: str) -> str:
70
  text = text.strip()
71
- markers = ["assistant:", "answer:", "final answer:"]
72
  for m in markers:
73
  if m.lower() in text.lower():
74
  text = text.lower().split(m, 1)[-1].strip()
@@ -77,96 +83,110 @@ def extract_final_answer(text: str) -> str:
77
 
78
  # =========================================================
79
  # Qwen inference
 
80
  def answer_question(question):
81
  context = retrieve_context(question)
82
  messages = [
83
  {"role": "system", "content": (
84
  "You are a strict document-based Q&A assistant.\n"
85
- "Answer ONLY the question in 1-2 sentences.\n"
86
- "If not found, say 'I could not find this information in the document.'"
 
 
 
87
  )},
88
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
89
  ]
90
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
91
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
92
  with torch.no_grad():
93
  output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.3, do_sample=True)
 
94
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
95
  return extract_final_answer(decoded)
96
 
97
  # =========================================================
98
- # TTS via API
99
- def tts_via_api(text: str):
 
 
 
 
 
 
 
 
 
100
  try:
101
- resp = requests.post(TTS_API_URL, json={"text": text}, timeout=60)
102
  resp.raise_for_status()
103
- audio_b64 = resp.json().get("audio", "")
 
104
  if not audio_b64:
105
  return None
106
- audio_bytes = base64.b64decode(audio_b64.split(",")[-1])
107
- wav, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
108
- return wav, sr
109
  except Exception as e:
110
  print(f"TTS API error: {e}")
111
  return None
112
 
113
  # =========================================================
114
- # Chat function (text + audio separate boxes)
 
115
  def chat(user_message, history):
116
  if not user_message.strip():
117
  return "", history
 
118
  try:
119
- # 1️⃣ Text answer
120
  answer_text = answer_question(user_message)
121
 
122
- # 2️⃣ Audio
123
- tts_result = tts_via_api(answer_text)
124
- if tts_result is not None:
125
- wav, sr = tts_result
126
- audio_output = (sr, wav)
127
- else:
128
- audio_output = None
129
-
130
- # 3️⃣ Append as separate text + audio
131
- history.append((user_message, answer_text, audio_output))
132
-
 
 
 
 
133
  except Exception as e:
134
  print(e)
135
- history.append((user_message, "⚠️ Error generating answer or audio.", None))
 
 
 
 
136
  return "", history
137
 
138
  def reset_chat():
139
  return []
140
 
141
  # =========================================================
142
- # Gradio UI
 
143
  def build_ui():
144
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
145
- gr.Markdown("# 📄 Qwen Document Assistant + TTS\nAsk a question and get a text + playable audio response.")
146
-
147
- chatbot = gr.Chatbot(height=500, type="messages") # 'messages' so we can use custom formatting
148
-
149
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
150
  send = gr.Button("Send")
151
  clear = gr.Button("🧹 Clear")
152
 
153
- def format_history(history):
154
- formatted = []
155
- for user_msg, bot_text, bot_audio in history:
156
- formatted.append([f"**You:** {user_msg}", None])
157
- formatted.append([f"**Bot:** {bot_text}", bot_audio])
158
- return formatted
159
-
160
- def chat_with_format(msg_input, history):
161
- _, history = chat(msg_input, history)
162
- return "", format_history(history)
163
-
164
- send.click(chat_with_format, [msg, chatbot], [msg, chatbot])
165
- msg.submit(chat_with_format, [msg, chatbot], [msg, chatbot])
166
  clear.click(reset_chat, outputs=chatbot)
167
 
168
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
169
 
 
 
170
  # =========================================================
171
  if __name__ == "__main__":
172
  print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}")
 
1
  import os
2
+ import requests
3
  import torch
4
  import gradio as gr
5
  import numpy as np
 
 
 
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
 
9
  # =========================================================
10
  # Configuration
11
+ # =========================================================
12
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
13
  DOC_FILE = "general.md"
14
  MAX_NEW_TOKENS = 200
15
  TOP_K = 3
16
+
17
+ # Your TTS FastAPI endpoint
18
  TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
19
 
20
  # =========================================================
21
  # Paths
22
+ # =========================================================
23
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
24
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
25
  if not os.path.exists(DOC_PATH):
26
+ raise RuntimeError(f"{DOC_FILE} not found next to app.py")
27
 
28
  # =========================================================
29
+ # Load Qwen model
30
+ # =========================================================
31
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
32
  model = AutoModelForCausalLM.from_pretrained(
33
  MODEL_ID,
 
38
  model.eval()
39
 
40
  # =========================================================
41
+ # Embedding Model for retrieval
42
+ # =========================================================
43
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
44
 
45
  # =========================================================
46
+ # Load document & chunk
47
+ # =========================================================
48
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping chunks of whitespace-separated words.

    Args:
        text: Source document as one string.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks. Must be
            smaller than ``chunk_size``; otherwise the scan could never
            advance (the original while-loop version hung forever here).

    Returns:
        list[str]: Chunks in document order; empty list for empty input.

    Raises:
        ValueError: If ``overlap`` >= ``chunk_size``.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    # range() reproduces the original cursor walk (i += chunk_size - overlap)
    # without the hand-rolled while loop.
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
56
 
 
61
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
62
 
63
  # =========================================================
64
+ # Retrieval
65
+ # =========================================================
66
  def retrieve_context(question, k=TOP_K):
67
  q_emb = embedder.encode([question], normalize_embeddings=True)
68
  scores = np.dot(DOC_EMBEDS, q_emb[0])
 
70
  return "\n\n".join([DOC_CHUNKS[i] for i in top_ids])
71
 
72
  # =========================================================
73
+ # Extract final answer
74
+ # =========================================================
75
  def extract_final_answer(text: str) -> str:
76
  text = text.strip()
77
+ markers = ["assistant:", "assistant", "answer:", "final answer:"]
78
  for m in markers:
79
  if m.lower() in text.lower():
80
  text = text.lower().split(m, 1)[-1].strip()
 
83
 
84
  # =========================================================
85
  # Qwen inference
86
+ # =========================================================
87
def answer_question(question):
    """Answer *question* from retrieved document context with the Qwen model.

    Retrieves the top-k chunks, builds a chat-template prompt, samples up to
    MAX_NEW_TOKENS, and returns the cleaned answer text.

    Args:
        question: The user's question as a plain string.

    Returns:
        str: The model's answer after marker stripping by
        ``extract_final_answer``.
    """
    context = retrieve_context(question)
    messages = [
        {"role": "system", "content": (
            "You are a strict document-based Q&A assistant.\n"
            "Answer ONLY the question.\n"
            "Do NOT repeat the context or the question.\n"
            "Respond in 1–2 sentences.\n"
            "If the answer is not present, say:\n"
            "'I could not find this information in the document.'"
        )},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.3, do_sample=True)

    # Decode only the newly generated tokens. Decoding output[0] in full
    # re-includes the prompt text and forces brittle marker-based stripping
    # downstream; slicing past the prompt length avoids that entirely.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return extract_final_answer(decoded)
108
 
109
  # =========================================================
110
+ # TTS via FastAPI
111
+ # =========================================================
112
def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
    """Synthesize *text* to speech through the remote Chatterbox TTS API.

    Posts the text plus voice parameters to TTS_API_URL and expects a JSON
    body with a base64-encoded WAV under the "audio" key.

    Returns:
        str | None: A ``data:audio/wav;base64,...`` URI on success, or
        ``None`` when the request fails or the response carries no audio.
    """
    request_body = {
        "text": text,
        "language_id": language_id,
        "mode": mode,
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfg_weight,
    }
    try:
        response = requests.post(TTS_API_URL, json=request_body, timeout=60)
        response.raise_for_status()
        encoded_audio = response.json().get("audio", "")
    except Exception as e:
        # Best-effort: any network/HTTP/JSON failure degrades to text-only.
        print(f"TTS API error: {e}")
        return None
    if not encoded_audio:
        return None
    return f"data:audio/wav;base64,{encoded_audio}"
132
 
133
  # =========================================================
134
+ # Chat function
135
+ # =========================================================
136
def chat(user_message, history):
    """Handle one chat turn: generate a text answer and optional TTS audio.

    Args:
        user_message: Raw textbox input from the user.
        history: Chatbot history in Gradio ``type="messages"`` format
            (list of ``{"role": ..., "content": ...}`` dicts); mutated
            in place.

    Returns:
        tuple[str, list]: Empty string (clears the textbox) and the
        updated history.
    """
    # Ignore blank submissions without touching the history.
    if not user_message.strip():
        return "", history

    try:
        # 1. Generate the text answer from the document.
        answer_text = answer_question(user_message)

        # 2. Best-effort audio; None when the TTS API is unavailable.
        audio_data = tts_via_api(answer_text)

        # 3. Append plain string/HTML content. The previous version pushed
        #    gr.Markdown.update / gr.Audio.update objects into the message
        #    content — .update() was removed in Gradio 4.x, and component
        #    update dicts are not valid 'messages' content, so nothing
        #    rendered. An inline <audio> tag plays the base64 data URI.
        history.append({"role": "user", "content": user_message})
        history.append({"role": "assistant", "content": f"**Bot:** {answer_text}"})
        if audio_data:
            history.append({
                "role": "assistant",
                "content": f'<audio controls src="{audio_data}"></audio>'
            })
    except Exception as e:
        print(e)
        history.append({
            "role": "assistant",
            "content": "**⚠️ Error generating response.**"
        })

    return "", history
167
 
168
def reset_chat():
    """Return a fresh, empty history to clear the chatbot display."""
    cleared_history = []
    return cleared_history
170
 
171
  # =========================================================
172
+ # Build UI
173
+ # =========================================================
174
def build_ui():
    """Assemble the Gradio chat interface and launch the server.

    Wires the Send button and textbox submit to ``chat`` and the Clear
    button to ``reset_chat``. Blocks until the server stops.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 📄 Qwen Document Assistant + TTS\nAsk questions and listen to answers!")
        chatbot = gr.Chatbot(height=500, type="messages")
        msg = gr.Textbox(placeholder="Ask a question...", lines=2)
        send = gr.Button("Send")
        clear = gr.Button("🧹 Clear")

        send.click(chat, [msg, chatbot], [msg, chatbot])
        msg.submit(chat, [msg, chatbot], [msg, chatbot])
        clear.click(reset_chat, outputs=chatbot)

    # Launch AFTER the Blocks context closes: Gradio finalizes the layout
    # in __exit__, and launching from inside the context was relying on
    # an incompletely-registered app.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
187
 
188
+ # =========================================================
189
+ # Entrypoint
190
  # =========================================================
191
  if __name__ == "__main__":
192
  print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}")