rahul7star commited on
Commit
a7ced7c
·
verified ·
1 Parent(s): 8d78afb

Update app_qwen_tts.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts.py +50 -37
app_qwen_tts.py CHANGED
@@ -1,8 +1,12 @@
1
  import os
 
 
2
  import requests
3
  import torch
4
- import gradio as gr
5
  import numpy as np
 
 
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
 
@@ -13,22 +17,23 @@ MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
13
  DOC_FILE = "general.md"
14
  MAX_NEW_TOKENS = 200
15
  TOP_K = 3
16
-
17
- # Your TTS FastAPI endpoint
18
- TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
19
 
20
  # =========================================================
21
  # Paths
22
  # =========================================================
23
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
24
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
 
25
  if not os.path.exists(DOC_PATH):
26
  raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
27
 
28
  # =========================================================
29
- # Load Qwen model
30
  # =========================================================
 
31
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
  MODEL_ID,
34
  device_map="auto",
@@ -36,21 +41,23 @@ model = AutoModelForCausalLM.from_pretrained(
36
  trust_remote_code=True
37
  )
38
  model.eval()
 
39
 
40
  # =========================================================
41
- # Embedding Model for retrieval
42
  # =========================================================
43
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
44
 
45
  # =========================================================
46
- # Load document & chunk
47
  # =========================================================
48
  def chunk_text(text, chunk_size=300, overlap=50):
49
  words = text.split()
50
  chunks = []
51
  i = 0
52
  while i < len(words):
53
- chunks.append(" ".join(words[i:i + chunk_size]))
 
54
  i += chunk_size - overlap
55
  return chunks
56
 
@@ -82,10 +89,11 @@ def extract_final_answer(text: str) -> str:
82
  return lines[-1] if lines else text
83
 
84
  # =========================================================
85
- # Qwen inference
86
  # =========================================================
87
  def answer_question(question):
88
  context = retrieve_context(question)
 
89
  messages = [
90
  {"role": "system", "content": (
91
  "You are a strict document-based Q&A assistant.\n"
@@ -97,17 +105,23 @@ def answer_question(question):
97
  )},
98
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
99
  ]
 
100
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
101
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
102
 
103
  with torch.no_grad():
104
- output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=0.3, do_sample=True)
 
 
 
 
 
105
 
106
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
107
  return extract_final_answer(decoded)
108
 
109
  # =========================================================
110
- # TTS via FastAPI
111
  # =========================================================
112
  def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
113
  payload = {
@@ -125,61 +139,60 @@ def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=
125
  audio_b64 = data.get("audio", "")
126
  if not audio_b64:
127
  return None
128
- return f"data:audio/wav;base64,{audio_b64}"
 
 
 
 
 
 
 
 
129
  except Exception as e:
130
  print(f"TTS API error: {e}")
131
  return None
132
 
133
  # =========================================================
134
- # Chat function
135
- # =========================================================
136
- # =========================================================
137
- # Chat function
138
  # =========================================================
139
  def chat(user_message, history):
140
  if not user_message.strip():
141
  return "", history
142
 
143
  try:
144
- # 1️⃣ Generate text answer
145
  answer_text = answer_question(user_message)
146
 
147
- # 2️⃣ Generate audio
148
- audio_data = tts_via_api(answer_text)
149
 
150
- # 3️⃣ Append messages in 'messages' format
151
  history.append({"role": "user", "content": user_message})
152
- if audio_data:
153
- history.append({
154
- "role": "assistant",
155
- "content": [f"**Bot:** {answer_text}", audio_data]
156
- })
157
  else:
158
- history.append({
159
- "role": "assistant",
160
- "content": f"**Bot:** {answer_text}"
161
- })
162
 
163
  except Exception as e:
164
  print(e)
165
- history.append({
166
- "role": "assistant",
167
- "content": "**⚠️ Error generating response.**"
168
- })
169
 
170
  return "", history
171
 
172
-
173
  def reset_chat():
174
  return []
175
 
176
  # =========================================================
177
- # Build UI
178
  # =========================================================
179
  def build_ui():
180
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
181
- gr.Markdown("## 📄 Qwen Document Assistant + TTS\nAsk questions and listen to answers!")
182
- chatbot = gr.Chatbot(height=500, type="messages")
 
 
183
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
184
  send = gr.Button("Send")
185
  clear = gr.Button("🧹 Clear")
 
1
  import os
2
+ import io
3
+ import base64
4
  import requests
5
  import torch
 
6
  import numpy as np
7
+ import soundfile as sf
8
+ import gradio as gr
9
+
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from sentence_transformers import SentenceTransformer
12
 
 
17
  DOC_FILE = "general.md"
18
  MAX_NEW_TOKENS = 200
19
  TOP_K = 3
20
+ TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts" # FastAPI TTS endpoint
 
 
21
 
22
  # =========================================================
23
  # Paths
24
  # =========================================================
25
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
26
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
27
+
28
  if not os.path.exists(DOC_PATH):
29
  raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
30
 
31
  # =========================================================
32
+ # Load Qwen Model
33
  # =========================================================
34
+ print("🔄 Loading Qwen model...")
35
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
36
+
37
  model = AutoModelForCausalLM.from_pretrained(
38
  MODEL_ID,
39
  device_map="auto",
 
41
  trust_remote_code=True
42
  )
43
  model.eval()
44
+ print("✅ Qwen model loaded.")
45
 
46
  # =========================================================
47
+ # Embedding Model
48
  # =========================================================
49
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
50
 
51
  # =========================================================
52
+ # Document Chunking
53
  # =========================================================
54
  def chunk_text(text, chunk_size=300, overlap=50):
55
  words = text.split()
56
  chunks = []
57
  i = 0
58
  while i < len(words):
59
+ chunk = words[i:i + chunk_size]
60
+ chunks.append(" ".join(chunk))
61
  i += chunk_size - overlap
62
  return chunks
63
 
 
89
  return lines[-1] if lines else text
90
 
91
  # =========================================================
92
+ # Qwen Inference
93
  # =========================================================
94
  def answer_question(question):
95
  context = retrieve_context(question)
96
+
97
  messages = [
98
  {"role": "system", "content": (
99
  "You are a strict document-based Q&A assistant.\n"
 
105
  )},
106
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
107
  ]
108
+
109
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
111
 
112
  with torch.no_grad():
113
+ output = model.generate(
114
+ **inputs,
115
+ max_new_tokens=MAX_NEW_TOKENS,
116
+ temperature=0.3,
117
+ do_sample=True
118
+ )
119
 
120
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
121
  return extract_final_answer(decoded)
122
 
123
  # =========================================================
124
+ # TTS via API (returns path to WAV for Gradio)
125
  # =========================================================
126
  def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
127
  payload = {
 
139
  audio_b64 = data.get("audio", "")
140
  if not audio_b64:
141
  return None
142
+
143
+ # Convert base64 to WAV for Gradio
144
+ audio_bytes = base64.b64decode(audio_b64)
145
+ audio_buffer = io.BytesIO(audio_bytes)
146
+ wav, sr = sf.read(audio_buffer, dtype="float32")
147
+ temp_path = "/tmp/response.wav"
148
+ sf.write(temp_path, wav, sr)
149
+ return temp_path
150
+
151
  except Exception as e:
152
  print(f"TTS API error: {e}")
153
  return None
154
 
155
  # =========================================================
156
+ # Chat function for Gradio
 
 
 
157
  # =========================================================
158
def chat(user_message, history):
    """Run one chat turn: answer from the document, synthesize speech, update history.

    Args:
        user_message: Raw text typed by the user; blank/whitespace-only
            input is ignored.
        history: Gradio "messages"-format chat history (list of role/content
            dicts); mutated in place and also returned.

    Returns:
        tuple[str, list]: An empty string (clears the textbox) and the
        updated history.
    """
    if not user_message.strip():
        return "", history

    try:
        answer_text = answer_question(user_message)
        audio_path = tts_via_api(answer_text)

        history.append({"role": "user", "content": user_message})

        bot_text = f"**Bot:** {answer_text}"
        # NOTE(review): list-valued content mixes markdown text with an audio
        # file path — confirm the installed Gradio version renders this in
        # type="messages" mode.
        content = [bot_text, audio_path] if audio_path else bot_text
        history.append({"role": "assistant", "content": content})

    except Exception as exc:
        # Best-effort: surface a generic error bubble rather than crashing the UI.
        print(exc)
        history.append({"role": "assistant", "content": "**⚠️ Error generating response.**"})

    return "", history
183
 
 
184
def reset_chat():
    """Return a fresh, empty chat history for the Chatbot component."""
    fresh_history = []
    return fresh_history
186
 
187
  # =========================================================
188
+ # Build Gradio UI
189
  # =========================================================
190
  def build_ui():
191
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
192
+ gr.Markdown("# 📄 Qwen Document Assistant + TTS")
193
+ gr.Markdown("Ask questions and hear the answers as audio.")
194
+
195
+ chatbot = gr.Chatbot(height=450, type="messages")
196
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
197
  send = gr.Button("Send")
198
  clear = gr.Button("🧹 Clear")