rahul7star committed on
Commit
5e3a18d
·
verified ·
1 Parent(s): e3e9cd7

hindi support test

Browse files
Files changed (1) hide show
  1. app_qwen_tts_fast.py +36 -38
app_qwen_tts_fast.py CHANGED
@@ -16,9 +16,8 @@ MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
  DOC_FILE = "general.md"
17
  TTS_API_URL = os.getenv(
18
  "TTS_API_URL",
19
- "ETS"
20
  )
21
- print(TTS_API_URL)
22
  MAX_NEW_TOKENS = 128
23
  TOP_K = 3
24
 
@@ -30,6 +29,9 @@ SESSION = requests.Session()
30
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
31
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
32
 
 
 
 
33
  with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
34
  DOC_TEXT = f.read()
35
 
@@ -46,25 +48,23 @@ def chunk_text(text, chunk_size=300, overlap=50):
46
 
47
  DOC_CHUNKS = chunk_text(DOC_TEXT)
48
 
49
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
50
- DOC_EMBEDS = embedder.encode(
51
- DOC_CHUNKS, normalize_embeddings=True, batch_size=32
52
- )
53
 
54
  # =====================================================
55
- # LOAD QWEN (FAST SETTINGS)
56
  # =====================================================
57
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
58
  model = AutoModelForCausalLM.from_pretrained(
59
  MODEL_ID,
60
- device_map="auto",
61
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
62
  trust_remote_code=True
63
  )
64
  model.eval()
65
 
66
  # =====================================================
67
- # RETRIEVAL
68
  # =====================================================
69
  @lru_cache(maxsize=256)
70
  def retrieve_context(question: str):
@@ -74,7 +74,7 @@ def retrieve_context(question: str):
74
  return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
75
 
76
  # =====================================================
77
- # QWEN ANSWER (FAST)
78
  # =====================================================
79
  def answer_question(question: str) -> str:
80
  context = retrieve_context(question)
@@ -100,7 +100,7 @@ def answer_question(question: str) -> str:
100
  messages, tokenize=False, add_generation_prompt=True
101
  )
102
 
103
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
104
 
105
  with torch.no_grad():
106
  output = model.generate(
@@ -114,20 +114,19 @@ def answer_question(question: str) -> str:
114
  return decoded.split("\n")[-1].strip()
115
 
116
  # =====================================================
117
- # TTS (FAST + SAFE)
118
  # =====================================================
119
  @lru_cache(maxsize=128)
120
- def generate_audio(text: str) -> str:
121
  payload = {
122
  "text": text,
123
- "language_id": "en",
124
  "mode": "Speak 🗣️"
125
  }
126
 
127
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
128
  r.raise_for_status()
129
 
130
- # Unique output path
131
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
132
 
133
  # Case 1: raw audio
@@ -138,34 +137,33 @@ def generate_audio(text: str) -> str:
138
 
139
  # Case 2: JSON base64
140
  data = r.json()
141
- audio_b64 = (
142
- data.get("audio")
143
- or data.get("audio_base64")
144
- or data.get("wav")
145
- )
146
-
147
  if not audio_b64:
148
- raise RuntimeError(f"TTS API returned no audio field: {data}")
149
 
150
  audio_bytes = base64.b64decode(audio_b64)
151
-
152
  with open(wav_path, "wb") as f:
153
  f.write(audio_bytes)
154
 
155
  if os.path.getsize(wav_path) < 1000:
156
- raise RuntimeError("Generated audio file is too small")
157
 
158
  return wav_path
159
 
160
  # =====================================================
161
  # MAIN PIPELINE
162
  # =====================================================
163
- def run_pipeline(question):
164
  if not question.strip():
165
  return "", None
166
 
167
  answer = answer_question(question)
168
- audio_path = generate_audio(answer)
 
 
 
 
 
169
 
170
  return f"**Bot:** {answer}", audio_path
171
 
@@ -173,31 +171,31 @@ def run_pipeline(question):
173
  # UI
174
  # =====================================================
175
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
176
-
177
 
178
  with gr.Row():
179
- with gr.Column():
180
  user_input = gr.Textbox(
181
  label="Your Question",
182
  placeholder="Who is CEO of OhamLab?",
183
  lines=3
184
  )
 
 
 
 
 
185
  ask_btn = gr.Button("Ask")
186
 
187
- with gr.Column():
188
  answer_text = gr.Markdown()
189
  answer_audio = gr.Audio(type="filepath")
190
 
191
  ask_btn.click(
192
  fn=run_pipeline,
193
- inputs=user_input,
194
  outputs=[answer_text, answer_audio]
195
  )
196
 
197
- demo.queue() # enable long-running jobs (5 min audio OK)
198
-
199
- demo.launch(
200
- server_name="0.0.0.0",
201
- server_port=7860,
202
- share=False
203
- )
 
16
  DOC_FILE = "general.md"
17
  TTS_API_URL = os.getenv(
18
  "TTS_API_URL",
19
+ "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
20
  )
 
21
  MAX_NEW_TOKENS = 128
22
  TOP_K = 3
23
 
 
29
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
30
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
31
 
32
+ if not os.path.exists(DOC_PATH):
33
+ raise RuntimeError(f"{DOC_FILE} not found")
34
+
35
  with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
36
  DOC_TEXT = f.read()
37
 
 
48
 
49
  DOC_CHUNKS = chunk_text(DOC_TEXT)
50
 
51
+ embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
52
+ DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, batch_size=32)
 
 
53
 
54
  # =====================================================
55
+ # LOAD QWEN MODEL (CPU only)
56
  # =====================================================
57
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
58
  model = AutoModelForCausalLM.from_pretrained(
59
  MODEL_ID,
60
+ device_map="cpu", # strictly CPU
61
+ torch_dtype=torch.float32,
62
  trust_remote_code=True
63
  )
64
  model.eval()
65
 
66
  # =====================================================
67
+ # RETRIEVAL WITH CACHE
68
  # =====================================================
69
  @lru_cache(maxsize=256)
70
  def retrieve_context(question: str):
 
74
  return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
75
 
76
  # =====================================================
77
+ # QWEN ANSWER (CPU optimized)
78
  # =====================================================
79
  def answer_question(question: str) -> str:
80
  context = retrieve_context(question)
 
100
  messages, tokenize=False, add_generation_prompt=True
101
  )
102
 
103
+ inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
104
 
105
  with torch.no_grad():
106
  output = model.generate(
 
114
  return decoded.split("\n")[-1].strip()
115
 
116
  # =====================================================
117
+ # TTS (CPU safe, flexible language)
118
  # =====================================================
119
  @lru_cache(maxsize=128)
120
+ def generate_audio(text: str, language_id: str = "en") -> str:
121
  payload = {
122
  "text": text,
123
+ "language_id": language_id,
124
  "mode": "Speak 🗣️"
125
  }
126
 
127
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
128
  r.raise_for_status()
129
 
 
130
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
131
 
132
  # Case 1: raw audio
 
137
 
138
  # Case 2: JSON base64
139
  data = r.json()
140
+ audio_b64 = data.get("audio") or data.get("audio_base64") or data.get("wav")
 
 
 
 
 
141
  if not audio_b64:
142
+ raise RuntimeError(f"TTS API returned no audio: {data}")
143
 
144
  audio_bytes = base64.b64decode(audio_b64)
 
145
  with open(wav_path, "wb") as f:
146
  f.write(audio_bytes)
147
 
148
  if os.path.getsize(wav_path) < 1000:
149
+ raise RuntimeError("Generated audio file too small")
150
 
151
  return wav_path
152
 
153
  # =====================================================
154
  # MAIN PIPELINE
155
  # =====================================================
156
+ def run_pipeline(question: str, language_id: str):
157
  if not question.strip():
158
  return "", None
159
 
160
  answer = answer_question(question)
161
+
162
+ try:
163
+ audio_path = generate_audio(answer, language_id)
164
+ except Exception as e:
165
+ print("TTS generation failed:", e)
166
+ audio_path = None
167
 
168
  return f"**Bot:** {answer}", audio_path
169
 
 
171
  # UI
172
  # =====================================================
173
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
174
+ gr.Markdown("# 📄 Qwen CPU Assistant + TTS")
175
 
176
  with gr.Row():
177
+ with gr.Column(scale=1):
178
  user_input = gr.Textbox(
179
  label="Your Question",
180
  placeholder="Who is CEO of OhamLab?",
181
  lines=3
182
  )
183
+ language_dropdown = gr.Dropdown(
184
+ label="TTS Language",
185
+ choices=["en", "hi"],
186
+ value="en"
187
+ )
188
  ask_btn = gr.Button("Ask")
189
 
190
+ with gr.Column(scale=1):
191
  answer_text = gr.Markdown()
192
  answer_audio = gr.Audio(type="filepath")
193
 
194
  ask_btn.click(
195
  fn=run_pipeline,
196
+ inputs=[user_input, language_dropdown],
197
  outputs=[answer_text, answer_audio]
198
  )
199
 
200
+ demo.queue() # long-running jobs OK (up to 5 min audio)
201
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)