rahul7star commited on
Commit
e3e9cd7
·
verified ·
1 Parent(s): e3405bf

back before optim1

Browse files
Files changed (1) hide show
  1. app_qwen_tts_fast.py +34 -69
app_qwen_tts_fast.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  import gradio as gr
7
  import numpy as np
8
  from functools import lru_cache
9
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
10
  from sentence_transformers import SentenceTransformer
11
 
12
  # =====================================================
@@ -14,15 +14,13 @@ from sentence_transformers import SentenceTransformer
14
  # =====================================================
15
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
  DOC_FILE = "general.md"
17
-
18
  TTS_API_URL = os.getenv(
19
  "TTS_API_URL",
20
- "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
21
  )
22
-
23
- MAX_NEW_TOKENS = 80 # 🔥 shorter = faster
24
  TOP_K = 3
25
- MIN_RELEVANCE_SCORE = 0.35 # 🔒 anti-hallucination
26
 
27
  SESSION = requests.Session()
28
 
@@ -50,73 +48,46 @@ DOC_CHUNKS = chunk_text(DOC_TEXT)
50
 
51
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
52
  DOC_EMBEDS = embedder.encode(
53
- DOC_CHUNKS,
54
- normalize_embeddings=True,
55
- batch_size=32
56
  )
57
 
58
  # =====================================================
59
- # LOAD QWEN (QUANTIZED IF POSSIBLE)
60
  # =====================================================
61
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
62
-
63
- bnb_config = None
64
- if torch.cuda.is_available():
65
- try:
66
- bnb_config = BitsAndBytesConfig(
67
- load_in_4bit=True,
68
- bnb_4bit_compute_dtype=torch.float16,
69
- bnb_4bit_use_double_quant=True,
70
- bnb_4bit_quant_type="nf4",
71
- )
72
- print("✅ Using 4-bit quantization")
73
- except Exception:
74
- print("⚠️ bitsandbytes not available, loading normal model")
75
-
76
  model = AutoModelForCausalLM.from_pretrained(
77
  MODEL_ID,
78
  device_map="auto",
79
- quantization_config=bnb_config,
80
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
81
  trust_remote_code=True
82
  )
83
  model.eval()
84
 
85
  # =====================================================
86
- # RETRIEVAL (STRICT)
87
  # =====================================================
88
  @lru_cache(maxsize=256)
89
  def retrieve_context(question: str):
90
  q_emb = embedder.encode([question], normalize_embeddings=True)
91
  scores = np.dot(DOC_EMBEDS, q_emb[0])
92
-
93
  top_ids = scores.argsort()[-TOP_K:][::-1]
94
- top_score = scores[top_ids[0]]
95
-
96
- if top_score < MIN_RELEVANCE_SCORE:
97
- return None
98
-
99
  return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
100
 
101
  # =====================================================
102
- # ANSWER (NO HALLUCINATION)
103
  # =====================================================
104
- def answer_question(question: str):
105
  context = retrieve_context(question)
106
 
107
- # 🚨 Abort early
108
- if context is None:
109
- return None
110
-
111
  messages = [
112
  {
113
  "role": "system",
114
  "content": (
115
- "You are a STRICT document-based assistant.\n"
116
- "ONLY answer if the information is explicitly present.\n"
117
- "If not found, reply EXACTLY:\n"
118
- "'I could not find this information in the document.'\n"
119
- "Do NOT explain. Do NOT guess."
120
  )
121
  },
122
  {
@@ -136,23 +107,17 @@ def answer_question(question: str):
136
  **inputs,
137
  max_new_tokens=MAX_NEW_TOKENS,
138
  do_sample=False,
139
- temperature=0.0,
140
  use_cache=True
141
  )
142
 
143
- decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
144
- final = decoded.split("\n")[-1].strip()
145
-
146
- if "could not find this information" in final.lower():
147
- return None
148
-
149
- return final
150
 
151
  # =====================================================
152
- # TTS (SAFE + CACHED)
153
  # =====================================================
154
  @lru_cache(maxsize=128)
155
- def generate_audio(text: str):
156
  payload = {
157
  "text": text,
158
  "language_id": "en",
@@ -162,13 +127,16 @@ def generate_audio(text: str):
162
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
163
  r.raise_for_status()
164
 
 
165
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
166
 
 
167
  if r.headers.get("content-type", "").startswith("audio"):
168
  with open(wav_path, "wb") as f:
169
  f.write(r.content)
170
  return wav_path
171
 
 
172
  data = r.json()
173
  audio_b64 = (
174
  data.get("audio")
@@ -177,41 +145,38 @@ def generate_audio(text: str):
177
  )
178
 
179
  if not audio_b64:
180
- raise RuntimeError(f"TTS API returned no audio: {data}")
 
 
181
 
182
  with open(wav_path, "wb") as f:
183
- f.write(base64.b64decode(audio_b64))
184
 
185
  if os.path.getsize(wav_path) < 1000:
186
- raise RuntimeError("Generated audio file is empty")
187
 
188
  return wav_path
189
 
190
  # =====================================================
191
- # PIPELINE
192
  # =====================================================
193
  def run_pipeline(question):
194
  if not question.strip():
195
  return "", None
196
 
197
  answer = answer_question(question)
198
-
199
- # 🚨 FAST EXIT — NO AUDIO
200
- if answer is None:
201
- return (
202
- "**Bot:** I could not find this information in the document.",
203
- None
204
- )
205
-
206
  audio_path = generate_audio(answer)
 
207
  return f"**Bot:** {answer}", audio_path
208
 
209
  # =====================================================
210
  # UI
211
  # =====================================================
212
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
 
213
  with gr.Row():
214
- with gr.Column(scale=1):
215
  user_input = gr.Textbox(
216
  label="Your Question",
217
  placeholder="Who is CEO of OhamLab?",
@@ -219,9 +184,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
219
  )
220
  ask_btn = gr.Button("Ask")
221
 
222
- with gr.Column(scale=1):
223
  answer_text = gr.Markdown()
224
- answer_audio = gr.Audio(type="filepath", label="Assistant Voice")
225
 
226
  ask_btn.click(
227
  fn=run_pipeline,
@@ -229,7 +194,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
229
  outputs=[answer_text, answer_audio]
230
  )
231
 
232
- demo.queue()
233
 
234
  demo.launch(
235
  server_name="0.0.0.0",
 
6
  import gradio as gr
7
  import numpy as np
8
  from functools import lru_cache
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from sentence_transformers import SentenceTransformer
11
 
12
  # =====================================================
 
14
  # =====================================================
15
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
  DOC_FILE = "general.md"
 
17
  TTS_API_URL = os.getenv(
18
  "TTS_API_URL",
19
+ "ETS"
20
  )
21
+ print(TTS_API_URL)
22
+ MAX_NEW_TOKENS = 128
23
  TOP_K = 3
 
24
 
25
  SESSION = requests.Session()
26
 
 
48
 
49
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
50
  DOC_EMBEDS = embedder.encode(
51
+ DOC_CHUNKS, normalize_embeddings=True, batch_size=32
 
 
52
  )
53
 
54
  # =====================================================
55
+ # LOAD QWEN (FAST SETTINGS)
56
  # =====================================================
57
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  model = AutoModelForCausalLM.from_pretrained(
59
  MODEL_ID,
60
  device_map="auto",
 
61
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
62
  trust_remote_code=True
63
  )
64
  model.eval()
65
 
66
  # =====================================================
67
+ # RETRIEVAL
68
  # =====================================================
69
  @lru_cache(maxsize=256)
70
  def retrieve_context(question: str):
71
  q_emb = embedder.encode([question], normalize_embeddings=True)
72
  scores = np.dot(DOC_EMBEDS, q_emb[0])
 
73
  top_ids = scores.argsort()[-TOP_K:][::-1]
 
 
 
 
 
74
  return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
75
 
76
  # =====================================================
77
+ # QWEN ANSWER (FAST)
78
  # =====================================================
79
+ def answer_question(question: str) -> str:
80
  context = retrieve_context(question)
81
 
 
 
 
 
82
  messages = [
83
  {
84
  "role": "system",
85
  "content": (
86
+ "You are a strict document-based Q&A assistant.\n"
87
+ "Answer ONLY the question.\n"
88
+ "Respond in 1 short sentence.\n"
89
+ "If not found, say:\n"
90
+ "'I could not find this information in the document.'"
91
  )
92
  },
93
  {
 
107
  **inputs,
108
  max_new_tokens=MAX_NEW_TOKENS,
109
  do_sample=False,
 
110
  use_cache=True
111
  )
112
 
113
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
114
+ return decoded.split("\n")[-1].strip()
 
 
 
 
 
115
 
116
  # =====================================================
117
+ # TTS (FAST + SAFE)
118
  # =====================================================
119
  @lru_cache(maxsize=128)
120
+ def generate_audio(text: str) -> str:
121
  payload = {
122
  "text": text,
123
  "language_id": "en",
 
127
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
128
  r.raise_for_status()
129
 
130
+ # Unique output path
131
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
132
 
133
+ # Case 1: raw audio
134
  if r.headers.get("content-type", "").startswith("audio"):
135
  with open(wav_path, "wb") as f:
136
  f.write(r.content)
137
  return wav_path
138
 
139
+ # Case 2: JSON base64
140
  data = r.json()
141
  audio_b64 = (
142
  data.get("audio")
 
145
  )
146
 
147
  if not audio_b64:
148
+ raise RuntimeError(f"TTS API returned no audio field: {data}")
149
+
150
+ audio_bytes = base64.b64decode(audio_b64)
151
 
152
  with open(wav_path, "wb") as f:
153
+ f.write(audio_bytes)
154
 
155
  if os.path.getsize(wav_path) < 1000:
156
+ raise RuntimeError("Generated audio file is too small")
157
 
158
  return wav_path
159
 
160
  # =====================================================
161
+ # MAIN PIPELINE
162
  # =====================================================
163
  def run_pipeline(question):
164
  if not question.strip():
165
  return "", None
166
 
167
  answer = answer_question(question)
 
 
 
 
 
 
 
 
168
  audio_path = generate_audio(answer)
169
+
170
  return f"**Bot:** {answer}", audio_path
171
 
172
  # =====================================================
173
  # UI
174
  # =====================================================
175
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
176
+
177
+
178
  with gr.Row():
179
+ with gr.Column():
180
  user_input = gr.Textbox(
181
  label="Your Question",
182
  placeholder="Who is CEO of OhamLab?",
 
184
  )
185
  ask_btn = gr.Button("Ask")
186
 
187
+ with gr.Column():
188
  answer_text = gr.Markdown()
189
+ answer_audio = gr.Audio(type="filepath")
190
 
191
  ask_btn.click(
192
  fn=run_pipeline,
 
194
  outputs=[answer_text, answer_audio]
195
  )
196
 
197
+ demo.queue() # enable long-running jobs (5 min audio OK)
198
 
199
  demo.launch(
200
  server_name="0.0.0.0",