rahul7star committed on
Commit
94cbcd2
·
verified ·
1 Parent(s): d153534

Update app_qwen_tts.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts.py +40 -54
app_qwen_tts.py CHANGED
@@ -2,9 +2,11 @@ import os
2
  import torch
3
  import gradio as gr
4
  import numpy as np
5
- import soundfile as sf
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
 
 
8
 
9
  # =========================================================
10
  # Configuration
@@ -14,24 +16,22 @@ DOC_FILE = "general.md"
14
 
15
  MAX_NEW_TOKENS = 200
16
  TOP_K = 3
17
- TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts" # FastAPI TTS endpoint
 
 
18
 
19
  # =========================================================
20
  # Paths
21
  # =========================================================
22
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
24
-
25
  if not os.path.exists(DOC_PATH):
26
  raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
27
 
28
  # =========================================================
29
  # Load Qwen Model
30
  # =========================================================
31
- tokenizer = AutoTokenizer.from_pretrained(
32
- MODEL_ID, trust_remote_code=True
33
- )
34
-
35
  model = AutoModelForCausalLM.from_pretrained(
36
  MODEL_ID,
37
  device_map="auto",
@@ -90,25 +90,17 @@ def extract_final_answer(text: str) -> str:
90
  # =========================================================
91
  def answer_question(question):
92
  context = retrieve_context(question)
93
-
94
  messages = [
95
- {
96
- "role": "system",
97
- "content": (
98
- "You are a strict document-based Q&A assistant.\n"
99
- "Answer ONLY the question.\n"
100
- "Do NOT repeat the context or the question.\n"
101
- "Respond in 1–2 sentences.\n"
102
- "If the answer is not present, say:\n"
103
- "'I could not find this information in the document.'"
104
- )
105
- },
106
- {
107
- "role": "user",
108
- "content": f"Context:\n{context}\n\nQuestion:\n{question}"
109
- }
110
  ]
111
-
112
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
113
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
114
 
@@ -119,68 +111,62 @@ def answer_question(question):
119
  temperature=0.3,
120
  do_sample=True
121
  )
122
-
123
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
124
  return extract_final_answer(decoded)
125
 
126
  # =========================================================
127
- # CPU-friendly TTS
128
-
129
  # =========================================================
130
- def tts_via_api(text: str, language_id="en", mode="Speak 🗣️", exaggeration=0.5, temperature=0.8, cfg_weight=0.5):
131
- payload = {
132
- "text": text,
133
- "language_id": language_id,
134
- "mode": mode,
135
- "exaggeration": exaggeration,
136
- "temperature": temperature,
137
- "cfg_weight": cfg_weight
138
- }
139
  try:
 
140
  resp = requests.post(TTS_API_URL, json=payload, timeout=60)
141
  resp.raise_for_status()
142
  data = resp.json()
143
  audio_b64 = data.get("audio", "")
144
  if not audio_b64:
145
  return None
146
- return f"data:audio/wav;base64,{audio_b64}"
 
 
 
 
 
147
  except Exception as e:
148
  print(f"TTS API error: {e}")
149
  return None
150
 
151
  # =========================================================
152
- # Gradio Chatbot function
153
  # =========================================================
154
  def chat(user_message, history):
155
  if not user_message.strip():
156
  return "", history
157
-
158
  try:
159
- # 1️⃣ Generate text answer
160
  answer_text = answer_question(user_message)
161
 
162
- # 2️⃣ Generate audio
163
- sr, wav = tts_via_api(answer_text)
164
-
165
- # Save temp wav for Gradio audio player
166
- audio_path = "/tmp/output.wav"
167
- import soundfile as sf
168
- sf.write(audio_path, wav, sr)
169
-
170
- # 3️⃣ Append as tuple (text + audio)
171
- history.append((user_message, [answer_text, audio_path]))
172
-
173
  except Exception as e:
174
  print(e)
175
- history.append((user_message, ["⚠️ Error generating answer or audio."]))
176
-
177
  return "", history
178
 
179
  def reset_chat():
180
  return []
181
 
182
  # =========================================================
183
- # Build UI
184
  # =========================================================
185
  def build_ui():
186
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
2
  import torch
3
  import gradio as gr
4
  import numpy as np
5
+ import requests
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
+ import base64
9
+ import io
10
 
11
  # =========================================================
12
  # Configuration
 
16
 
17
  MAX_NEW_TOKENS = 200
18
  TOP_K = 3
19
+
20
+ # FastAPI TTS endpoint
21
+ TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
22
 
23
  # =========================================================
24
  # Paths
25
  # =========================================================
26
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
27
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
 
28
  if not os.path.exists(DOC_PATH):
29
  raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
30
 
31
  # =========================================================
32
  # Load Qwen Model
33
  # =========================================================
34
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
35
  model = AutoModelForCausalLM.from_pretrained(
36
  MODEL_ID,
37
  device_map="auto",
 
90
  # =========================================================
91
  def answer_question(question):
92
  context = retrieve_context(question)
 
93
  messages = [
94
+ {"role": "system", "content": (
95
+ "You are a strict document-based Q&A assistant.\n"
96
+ "Answer ONLY the question.\n"
97
+ "Do NOT repeat the context or the question.\n"
98
+ "Respond in 1–2 sentences.\n"
99
+ "If the answer is not present, say:\n"
100
+ "'I could not find this information in the document.'"
101
+ )},
102
+ {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
 
 
 
 
 
 
103
  ]
 
104
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
106
 
 
111
  temperature=0.3,
112
  do_sample=True
113
  )
 
114
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
115
  return extract_final_answer(decoded)
116
 
117
  # =========================================================
118
+ # TTS via API (returns NumPy audio)
 
119
  # =========================================================
120
+ def tts_via_api(text: str):
 
 
 
 
 
 
 
 
121
  try:
122
+ payload = {"text": text}
123
  resp = requests.post(TTS_API_URL, json=payload, timeout=60)
124
  resp.raise_for_status()
125
  data = resp.json()
126
  audio_b64 = data.get("audio", "")
127
  if not audio_b64:
128
  return None
129
+ # Decode base64 audio to bytes
130
+ audio_bytes = base64.b64decode(audio_b64.split(",")[-1])
131
+ # Convert to np.float32
132
+ import soundfile as sf
133
+ wav, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
134
+ return wav, sr
135
  except Exception as e:
136
  print(f"TTS API error: {e}")
137
  return None
138
 
139
  # =========================================================
140
+ # Gradio chat function
141
  # =========================================================
142
  def chat(user_message, history):
143
  if not user_message.strip():
144
  return "", history
 
145
  try:
146
+ # Text answer
147
  answer_text = answer_question(user_message)
148
 
149
+ # TTS
150
+ tts_result = tts_via_api(answer_text)
151
+ if tts_result is not None:
152
+ wav, sr = tts_result
153
+ # Gradio can take NumPy array + sample rate directly
154
+ audio_output = (sr, wav)
155
+ else:
156
+ audio_output = None
157
+
158
+ # Append tuple with text and playable audio
159
+ history.append((user_message, [f"**Bot:** {answer_text}", audio_output]))
160
  except Exception as e:
161
  print(e)
162
+ history.append((user_message, ["⚠️ Error generating answer or audio.", None]))
 
163
  return "", history
164
 
165
  def reset_chat():
166
  return []
167
 
168
  # =========================================================
169
+ # Gradio UI
170
  # =========================================================
171
  def build_ui():
172
  with gr.Blocks(theme=gr.themes.Soft()) as demo: