rahul7star commited on
Commit
e3405bf
·
verified ·
1 Parent(s): f7d140b
Files changed (1) hide show
  1. app_qwen_tts_fast.py +69 -35
app_qwen_tts_fast.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  import gradio as gr
7
  import numpy as np
8
  from functools import lru_cache
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from sentence_transformers import SentenceTransformer
11
 
12
  # =====================================================
@@ -14,13 +14,15 @@ from sentence_transformers import SentenceTransformer
14
  # =====================================================
15
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
  DOC_FILE = "general.md"
 
17
  TTS_API_URL = os.getenv(
18
  "TTS_API_URL",
19
- "ETS"
20
  )
21
- print(TTS_API_URL)
22
- MAX_NEW_TOKENS = 128
23
  TOP_K = 3
 
24
 
25
  SESSION = requests.Session()
26
 
@@ -48,46 +50,73 @@ DOC_CHUNKS = chunk_text(DOC_TEXT)
48
 
49
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
50
  DOC_EMBEDS = embedder.encode(
51
- DOC_CHUNKS, normalize_embeddings=True, batch_size=32
 
 
52
  )
53
 
54
  # =====================================================
55
- # LOAD QWEN (FAST SETTINGS)
56
  # =====================================================
57
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  model = AutoModelForCausalLM.from_pretrained(
59
  MODEL_ID,
60
  device_map="auto",
 
61
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
62
  trust_remote_code=True
63
  )
64
  model.eval()
65
 
66
  # =====================================================
67
- # RETRIEVAL
68
  # =====================================================
69
  @lru_cache(maxsize=256)
70
  def retrieve_context(question: str):
71
  q_emb = embedder.encode([question], normalize_embeddings=True)
72
  scores = np.dot(DOC_EMBEDS, q_emb[0])
 
73
  top_ids = scores.argsort()[-TOP_K:][::-1]
 
 
 
 
 
74
  return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
75
 
76
  # =====================================================
77
- # QWEN ANSWER (FAST)
78
  # =====================================================
79
- def answer_question(question: str) -> str:
80
  context = retrieve_context(question)
81
 
 
 
 
 
82
  messages = [
83
  {
84
  "role": "system",
85
  "content": (
86
- "You are a strict document-based Q&A assistant.\n"
87
- "Answer ONLY the question.\n"
88
- "Respond in 1 short sentence.\n"
89
- "If not found, say:\n"
90
- "'I could not find this information in the document.'"
91
  )
92
  },
93
  {
@@ -107,17 +136,23 @@ def answer_question(question: str) -> str:
107
  **inputs,
108
  max_new_tokens=MAX_NEW_TOKENS,
109
  do_sample=False,
 
110
  use_cache=True
111
  )
112
 
113
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
114
- return decoded.split("\n")[-1].strip()
 
 
 
 
 
115
 
116
  # =====================================================
117
- # TTS (FAST + SAFE)
118
  # =====================================================
119
  @lru_cache(maxsize=128)
120
- def generate_audio(text: str) -> str:
121
  payload = {
122
  "text": text,
123
  "language_id": "en",
@@ -127,16 +162,13 @@ def generate_audio(text: str) -> str:
127
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
128
  r.raise_for_status()
129
 
130
- # Unique output path
131
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
132
 
133
- # Case 1: raw audio
134
  if r.headers.get("content-type", "").startswith("audio"):
135
  with open(wav_path, "wb") as f:
136
  f.write(r.content)
137
  return wav_path
138
 
139
- # Case 2: JSON base64
140
  data = r.json()
141
  audio_b64 = (
142
  data.get("audio")
@@ -145,38 +177,41 @@ def generate_audio(text: str) -> str:
145
  )
146
 
147
  if not audio_b64:
148
- raise RuntimeError(f"TTS API returned no audio field: {data}")
149
-
150
- audio_bytes = base64.b64decode(audio_b64)
151
 
152
  with open(wav_path, "wb") as f:
153
- f.write(audio_bytes)
154
 
155
  if os.path.getsize(wav_path) < 1000:
156
- raise RuntimeError("Generated audio file is too small")
157
 
158
  return wav_path
159
 
160
  # =====================================================
161
- # MAIN PIPELINE
162
  # =====================================================
163
  def run_pipeline(question):
164
  if not question.strip():
165
  return "", None
166
 
167
  answer = answer_question(question)
168
- audio_path = generate_audio(answer)
169
 
 
 
 
 
 
 
 
 
170
  return f"**Bot:** {answer}", audio_path
171
 
172
  # =====================================================
173
  # UI
174
  # =====================================================
175
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
176
-
177
-
178
  with gr.Row():
179
- with gr.Column():
180
  user_input = gr.Textbox(
181
  label="Your Question",
182
  placeholder="Who is CEO of OhamLab?",
@@ -184,9 +219,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
184
  )
185
  ask_btn = gr.Button("Ask")
186
 
187
- with gr.Column():
188
  answer_text = gr.Markdown()
189
- answer_audio = gr.Audio(type="filepath")
190
 
191
  ask_btn.click(
192
  fn=run_pipeline,
@@ -194,11 +229,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
194
  outputs=[answer_text, answer_audio]
195
  )
196
 
197
- demo.queue() # enable long-running jobs (5 min audio OK)
198
 
199
  demo.launch(
200
  server_name="0.0.0.0",
201
  server_port=7860,
202
  share=False
203
  )
204
-
 
6
  import gradio as gr
7
  import numpy as np
8
  from functools import lru_cache
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
10
  from sentence_transformers import SentenceTransformer
11
 
12
  # =====================================================
 
14
  # =====================================================
15
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
  DOC_FILE = "general.md"
17
+
18
  TTS_API_URL = os.getenv(
19
  "TTS_API_URL",
20
+ "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
21
  )
22
+
23
+ MAX_NEW_TOKENS = 80 # 🔥 shorter = faster
24
  TOP_K = 3
25
+ MIN_RELEVANCE_SCORE = 0.35 # 🔒 anti-hallucination
26
 
27
  SESSION = requests.Session()
28
 
 
50
 
51
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # lightweight sentence embedder for retrieval
# Pre-compute embeddings for every chunk once, so each query reduces to a dot product.
DOC_EMBEDS = embedder.encode(
    DOC_CHUNKS,
    normalize_embeddings=True,  # unit vectors -> dot product == cosine similarity
    batch_size=32
)
57
 
58
# =====================================================
# LOAD QWEN (QUANTIZED IF POSSIBLE)
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Build a 4-bit quantization config on CUDA boxes. NOTE: constructing
# BitsAndBytesConfig does NOT import bitsandbytes, so this try/except alone
# cannot detect a missing bitsandbytes install — the import error would only
# surface inside from_pretrained. The retry below handles that case.
bnb_config = None
if torch.cuda.is_available():
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        print("✅ Using 4-bit quantization")
    except Exception:
        bnb_config = None
        print("⚠️ bitsandbytes not available, loading normal model")


def _load_model(quant_config):
    """Load the causal LM, optionally with a quantization config."""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=quant_config,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True
    )


try:
    model = _load_model(bnb_config)
except Exception:
    # The quantized load is where a missing/broken bitsandbytes actually
    # fails; retry un-quantized instead of crashing the whole app at startup.
    if bnb_config is None:
        raise
    print("⚠️ 4-bit load failed, retrying without quantization")
    bnb_config = None
    model = _load_model(None)

model.eval()
84
 
85
# =====================================================
# RETRIEVAL (STRICT)
# =====================================================
@lru_cache(maxsize=256)
def retrieve_context(question: str):
    """Return the TOP_K most relevant document chunks joined together.

    Returns None when even the best-matching chunk scores below
    MIN_RELEVANCE_SCORE, signalling "this is not in the document".
    """
    query_vec = embedder.encode([question], normalize_embeddings=True)[0]
    similarities = np.dot(DOC_EMBEDS, query_vec)

    # Best-matching chunk indices, most similar first.
    best = similarities.argsort()[::-1][:TOP_K]

    if similarities[best[0]] < MIN_RELEVANCE_SCORE:
        return None  # nothing relevant enough — caller should abort

    return "\n\n".join(DOC_CHUNKS[idx] for idx in best)
100
 
101
  # =====================================================
102
+ # ANSWER (NO HALLUCINATION)
103
  # =====================================================
104
+ def answer_question(question: str):
105
  context = retrieve_context(question)
106
 
107
+ # 🚨 Abort early
108
+ if context is None:
109
+ return None
110
+
111
  messages = [
112
  {
113
  "role": "system",
114
  "content": (
115
+ "You are a STRICT document-based assistant.\n"
116
+ "ONLY answer if the information is explicitly present.\n"
117
+ "If not found, reply EXACTLY:\n"
118
+ "'I could not find this information in the document.'\n"
119
+ "Do NOT explain. Do NOT guess."
120
  )
121
  },
122
  {
 
136
  **inputs,
137
  max_new_tokens=MAX_NEW_TOKENS,
138
  do_sample=False,
139
+ temperature=0.0,
140
  use_cache=True
141
  )
142
 
143
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
144
+ final = decoded.split("\n")[-1].strip()
145
+
146
+ if "could not find this information" in final.lower():
147
+ return None
148
+
149
+ return final
150
 
151
  # =====================================================
152
+ # TTS (SAFE + CACHED)
153
  # =====================================================
154
  @lru_cache(maxsize=128)
155
+ def generate_audio(text: str):
156
  payload = {
157
  "text": text,
158
  "language_id": "en",
 
162
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
163
  r.raise_for_status()
164
 
 
165
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
166
 
 
167
  if r.headers.get("content-type", "").startswith("audio"):
168
  with open(wav_path, "wb") as f:
169
  f.write(r.content)
170
  return wav_path
171
 
 
172
  data = r.json()
173
  audio_b64 = (
174
  data.get("audio")
 
177
  )
178
 
179
  if not audio_b64:
180
+ raise RuntimeError(f"TTS API returned no audio: {data}")
 
 
181
 
182
  with open(wav_path, "wb") as f:
183
+ f.write(base64.b64decode(audio_b64))
184
 
185
  if os.path.getsize(wav_path) < 1000:
186
+ raise RuntimeError("Generated audio file is empty")
187
 
188
  return wav_path
189
 
190
# =====================================================
# PIPELINE
# =====================================================
def run_pipeline(question):
    """Answer *question* from the document and synthesize speech for it.

    Returns a (markdown_answer, audio_path) pair. audio_path is None for
    empty input or when the document holds no relevant information.
    """
    # Guard against None as well as blank input — Gradio can deliver None
    # for an untouched Textbox, and None.strip() would raise AttributeError.
    if not question or not question.strip():
        return "", None

    answer = answer_question(question)

    # 🚨 FAST EXIT — NO AUDIO: skip the TTS round-trip when nothing was found.
    if answer is None:
        return (
            "**Bot:** I could not find this information in the document.",
            None
        )

    audio_path = generate_audio(answer)
    return f"**Bot:** {answer}", audio_path
208
 
209
  # =====================================================
210
  # UI
211
  # =====================================================
212
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
 
213
  with gr.Row():
214
+ with gr.Column(scale=1):
215
  user_input = gr.Textbox(
216
  label="Your Question",
217
  placeholder="Who is CEO of OhamLab?",
 
219
  )
220
  ask_btn = gr.Button("Ask")
221
 
222
+ with gr.Column(scale=1):
223
  answer_text = gr.Markdown()
224
+ answer_audio = gr.Audio(type="filepath", label="Assistant Voice")
225
 
226
  ask_btn.click(
227
  fn=run_pipeline,
 
229
  outputs=[answer_text, answer_audio]
230
  )
231
 
232
+ demo.queue()
233
 
234
  demo.launch(
235
  server_name="0.0.0.0",
236
  server_port=7860,
237
  share=False
238
  )