rahul7star commited on
Commit
b75f9b3
·
verified ·
1 Parent(s): 94cbcd2
Files changed (1) hide show
  1. app_qwen_tts.py +24 -40
app_qwen_tts.py CHANGED
@@ -3,21 +3,19 @@ import torch
3
  import gradio as gr
4
  import numpy as np
5
  import requests
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
- from sentence_transformers import SentenceTransformer
8
  import base64
9
  import io
 
 
 
10
 
11
  # =========================================================
12
  # Configuration
13
  # =========================================================
14
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
15
  DOC_FILE = "general.md"
16
-
17
  MAX_NEW_TOKENS = 200
18
  TOP_K = 3
19
-
20
- # FastAPI TTS endpoint
21
  TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
22
 
23
  # =========================================================
@@ -26,7 +24,7 @@ TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
26
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
27
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
28
  if not os.path.exists(DOC_PATH):
29
- raise RuntimeError(f"{DOC_FILE} not found next to app.py")
30
 
31
  # =========================================================
32
  # Load Qwen Model
@@ -41,20 +39,19 @@ model = AutoModelForCausalLM.from_pretrained(
41
  model.eval()
42
 
43
  # =========================================================
44
- # Embedding Model
45
  # =========================================================
46
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
47
 
48
  # =========================================================
49
- # Document Chunking
50
  # =========================================================
51
  def chunk_text(text, chunk_size=300, overlap=50):
52
  words = text.split()
53
  chunks = []
54
  i = 0
55
  while i < len(words):
56
- chunk = words[i:i + chunk_size]
57
- chunks.append(" ".join(chunk))
58
  i += chunk_size - overlap
59
  return chunks
60
 
@@ -65,7 +62,7 @@ DOC_CHUNKS = chunk_text(DOC_TEXT)
65
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
66
 
67
  # =========================================================
68
- # Retrieval
69
  # =========================================================
70
  def retrieve_context(question, k=TOP_K):
71
  q_emb = embedder.encode([question], normalize_embeddings=True)
@@ -74,11 +71,11 @@ def retrieve_context(question, k=TOP_K):
74
  return "\n\n".join([DOC_CHUNKS[i] for i in top_ids])
75
 
76
  # =========================================================
77
- # Extract final answer
78
  # =========================================================
79
  def extract_final_answer(text: str) -> str:
80
  text = text.strip()
81
- markers = ["assistant:", "assistant", "answer:", "final answer:"]
82
  for m in markers:
83
  if m.lower() in text.lower():
84
  text = text.lower().split(m, 1)[-1].strip()
@@ -86,50 +83,36 @@ def extract_final_answer(text: str) -> str:
86
  return lines[-1] if lines else text
87
 
88
  # =========================================================
89
- # Qwen Inference
90
  # =========================================================
91
  def answer_question(question):
92
  context = retrieve_context(question)
93
  messages = [
94
  {"role": "system", "content": (
95
  "You are a strict document-based Q&A assistant.\n"
96
- "Answer ONLY the question.\n"
97
- "Do NOT repeat the context or the question.\n"
98
- "Respond in 1–2 sentences.\n"
99
- "If the answer is not present, say:\n"
100
- "'I could not find this information in the document.'"
101
  )},
102
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
103
  ]
104
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
106
-
107
  with torch.no_grad():
108
- output = model.generate(
109
- **inputs,
110
- max_new_tokens=MAX_NEW_TOKENS,
111
- temperature=0.3,
112
- do_sample=True
113
- )
114
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
115
  return extract_final_answer(decoded)
116
 
117
  # =========================================================
118
- # TTS via API (returns NumPy audio)
119
  # =========================================================
120
  def tts_via_api(text: str):
121
  try:
122
- payload = {"text": text}
123
- resp = requests.post(TTS_API_URL, json=payload, timeout=60)
124
  resp.raise_for_status()
125
- data = resp.json()
126
- audio_b64 = data.get("audio", "")
127
  if not audio_b64:
128
  return None
129
- # Decode base64 audio to bytes
130
  audio_bytes = base64.b64decode(audio_b64.split(",")[-1])
131
- # Convert to np.float32
132
- import soundfile as sf
133
  wav, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
134
  return wav, sr
135
  except Exception as e:
@@ -137,26 +120,26 @@ def tts_via_api(text: str):
137
  return None
138
 
139
  # =========================================================
140
- # Gradio chat function
141
  # =========================================================
142
  def chat(user_message, history):
143
  if not user_message.strip():
144
  return "", history
145
  try:
146
- # Text answer
147
  answer_text = answer_question(user_message)
148
 
149
- # TTS
150
  tts_result = tts_via_api(answer_text)
151
  if tts_result is not None:
152
  wav, sr = tts_result
153
- # Gradio can take NumPy array + sample rate directly
154
  audio_output = (sr, wav)
155
  else:
156
  audio_output = None
157
 
158
- # Append tuple with text and playable audio
159
  history.append((user_message, [f"**Bot:** {answer_text}", audio_output]))
 
160
  except Exception as e:
161
  print(e)
162
  history.append((user_message, ["⚠️ Error generating answer or audio.", None]))
@@ -170,7 +153,8 @@ def reset_chat():
170
  # =========================================================
171
  def build_ui():
172
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
173
- gr.Markdown("## 📄 Qwen Document Assistant + TTS")
 
174
  chatbot = gr.Chatbot(height=450, type="tuples")
175
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
176
  send = gr.Button("Send")
 
3
  import gradio as gr
4
  import numpy as np
5
  import requests
 
 
6
  import base64
7
  import io
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from sentence_transformers import SentenceTransformer
10
+ import soundfile as sf
11
 
# =========================================================
# Configuration
# =========================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # chat LLM used for document Q&A
DOC_FILE = "general.md"                  # knowledge-base document, loaded from the app directory

MAX_NEW_TOKENS = 200  # generation budget per answer
TOP_K = 3             # number of document chunks retrieved per question

# Remote FastAPI endpoint that converts answer text to speech
TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
21
  # =========================================================
 
24
# Resolve the knowledge-base document relative to this file and fail fast
# at import time if it is missing, instead of erroring later mid-request.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
if not os.path.exists(DOC_PATH):
    raise RuntimeError(f"{DOC_FILE} not found")
28
 
29
  # =========================================================
30
  # Load Qwen Model
 
39
  model.eval()
40
 
41
  # =========================================================
42
+ # Embeddings
43
  # =========================================================
44
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
45
 
46
# =========================================================
# Document chunking
# =========================================================
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into whitespace-token chunks of at most ``chunk_size`` words.

    Consecutive chunks share ``overlap`` words so that sentences falling on
    a chunk boundary remain retrievable as a whole.

    Args:
        text: Source document as a single string.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words repeated between adjacent chunks.

    Returns:
        list[str]: Chunks in document order; ``[]`` for empty input.

    Raises:
        ValueError: If ``overlap >= chunk_size`` — the original while-loop
            stride ``chunk_size - overlap`` would be <= 0 and loop forever.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap  # loop-invariant stride, always positive here
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
    return chunks
57
 
 
62
  DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
63
 
64
  # =========================================================
65
+ # Retrieve context
66
  # =========================================================
67
  def retrieve_context(question, k=TOP_K):
68
  q_emb = embedder.encode([question], normalize_embeddings=True)
 
71
  return "\n\n".join([DOC_CHUNKS[i] for i in top_ids])
72
 
73
  # =========================================================
74
+ # Extract answer
75
  # =========================================================
76
  def extract_final_answer(text: str) -> str:
77
  text = text.strip()
78
+ markers = ["assistant:", "answer:", "final answer:"]
79
  for m in markers:
80
  if m.lower() in text.lower():
81
  text = text.lower().split(m, 1)[-1].strip()
 
83
  return lines[-1] if lines else text
84
 
85
# =========================================================
# Qwen inference
# =========================================================
def answer_question(question):
    """Answer *question* using only content retrieved from the document.

    Builds a chat prompt around the retrieved context, samples a completion
    from the Qwen model, and post-processes it into a short answer string.

    Args:
        question: User question as plain text.

    Returns:
        str: The model's final answer (or its fallback "not found" sentence).
    """
    context = retrieve_context(question)
    messages = [
        {"role": "system", "content": (
            "You are a strict document-based Q&A assistant.\n"
            "Answer ONLY the question in 1-2 sentences.\n"
            "If not found, say 'I could not find this information in the document.'"
        )},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            do_sample=True,
        )
    # Fix: decode only the newly generated tokens. Decoding output[0] in full
    # echoes the prompt (context + question) back into the answer, which is
    # exactly what forces the fragile marker-splitting in extract_final_answer.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return extract_final_answer(decoded)
104
 
105
# =========================================================
# TTS via API
# =========================================================
def tts_via_api(text: str):
    """Synthesize *text* via the remote TTS endpoint (best effort).

    Posts the text to TTS_API_URL, expects a JSON body with a base64-encoded
    "audio" field (optionally a data URI), and decodes it to a waveform.

    Args:
        text: The answer text to vocalize.

    Returns:
        (waveform, sample_rate) on success, or None when the request fails
        or the response carries no audio.
    """
    try:
        resp = requests.post(TTS_API_URL, json={"text": text}, timeout=60)
        resp.raise_for_status()
        audio_b64 = resp.json().get("audio", "")
        if not audio_b64:
            return None
        # Strip a possible "data:audio/...;base64," prefix before decoding.
        audio_bytes = base64.b64decode(audio_b64.split(",")[-1])
        wav, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
        return wav, sr
    except Exception as e:
        # Fix: the caught exception was bound but never surfaced, making
        # failures silent. Log it, then degrade to text-only output.
        print(f"TTS API error: {e}")
        return None
121
 
122
  # =========================================================
123
+ # Chat function
124
  # =========================================================
125
  def chat(user_message, history):
126
  if not user_message.strip():
127
  return "", history
128
  try:
129
+ # 1️⃣ Text answer
130
  answer_text = answer_question(user_message)
131
 
132
+ # 2️⃣ Audio
133
  tts_result = tts_via_api(answer_text)
134
  if tts_result is not None:
135
  wav, sr = tts_result
 
136
  audio_output = (sr, wav)
137
  else:
138
  audio_output = None
139
 
140
+ # 3️⃣ Append nicely formatted response
141
  history.append((user_message, [f"**Bot:** {answer_text}", audio_output]))
142
+
143
  except Exception as e:
144
  print(e)
145
  history.append((user_message, ["⚠️ Error generating answer or audio.", None]))
 
153
  # =========================================================
154
  def build_ui():
155
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
156
+ gr.Markdown("# 📄 Qwen Document Assistant + TTS\nAsk a question and get a text + playable audio response.")
157
+
158
  chatbot = gr.Chatbot(height=450, type="tuples")
159
  msg = gr.Textbox(placeholder="Ask a question...", lines=2)
160
  send = gr.Button("Send")