rahul7star commited on
Commit
e96cf4a
·
verified ·
1 Parent(s): 7bdcb50

Multi-language support: add a Hindi knowledge document alongside English and route retrieval/TTS by language id

Browse files
Files changed (1) hide show
  1. app_qwen_tts_fast.py +41 -38
app_qwen_tts_fast.py CHANGED
@@ -13,27 +13,32 @@ from sentence_transformers import SentenceTransformer
13
  # CONFIG
14
  # =====================================================
15
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
- DOC_FILE = "general.md"
 
17
  TTS_API_URL = os.getenv(
18
  "TTS_API_URL",
19
  "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
20
  )
21
  MAX_NEW_TOKENS = 128
22
  TOP_K = 3
23
-
24
  SESSION = requests.Session()
25
 
26
  # =====================================================
27
- # LOAD DOCUMENT
28
  # =====================================================
29
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
30
- DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
 
 
 
 
 
31
 
32
- if not os.path.exists(DOC_PATH):
33
- raise RuntimeError(f"{DOC_FILE} not found")
34
 
35
- with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
36
- DOC_TEXT = f.read()
37
 
38
  # =====================================================
39
  # CHUNK + EMBED
@@ -46,10 +51,12 @@ def chunk_text(text, chunk_size=300, overlap=50):
46
  i += chunk_size - overlap
47
  return chunks
48
 
49
- DOC_CHUNKS = chunk_text(DOC_TEXT)
 
50
 
51
  embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
52
- DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, batch_size=32)
 
53
 
54
  # =====================================================
55
  # LOAD QWEN MODEL (CPU only)
@@ -57,7 +64,7 @@ DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, batch_size=3
57
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
58
  model = AutoModelForCausalLM.from_pretrained(
59
  MODEL_ID,
60
- device_map="cpu", # strictly CPU
61
  torch_dtype=torch.float32,
62
  trust_remote_code=True
63
  )
@@ -67,17 +74,22 @@ model.eval()
67
  # RETRIEVAL WITH CACHE
68
  # =====================================================
69
  @lru_cache(maxsize=256)
70
- def retrieve_context(question: str):
71
  q_emb = embedder.encode([question], normalize_embeddings=True)
72
- scores = np.dot(DOC_EMBEDS, q_emb[0])
73
- top_ids = scores.argsort()[-TOP_K:][::-1]
74
- return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
 
 
 
 
 
75
 
76
  # =====================================================
77
  # QWEN ANSWER (CPU optimized)
78
  # =====================================================
79
- def answer_question(question: str) -> str:
80
- context = retrieve_context(question)
81
 
82
  messages = [
83
  {
@@ -90,16 +102,10 @@ def answer_question(question: str) -> str:
90
  "'I could not find this information in the document.'"
91
  )
92
  },
93
- {
94
- "role": "user",
95
- "content": f"Context:\n{context}\n\nQuestion:\n{question}"
96
- }
97
  ]
98
 
99
- prompt = tokenizer.apply_chat_template(
100
- messages, tokenize=False, add_generation_prompt=True
101
- )
102
-
103
  inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
104
 
105
  with torch.no_grad():
@@ -118,24 +124,19 @@ def answer_question(question: str) -> str:
118
  # =====================================================
119
  @lru_cache(maxsize=128)
120
  def generate_audio(text: str, language_id: str = "en") -> str:
121
- payload = {
122
- "text": text,
123
- "language_id": language_id,
124
- "mode": "Speak 🗣️"
125
- }
126
-
127
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
128
  r.raise_for_status()
129
 
130
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
131
 
132
- # Case 1: raw audio
133
  if r.headers.get("content-type", "").startswith("audio"):
134
  with open(wav_path, "wb") as f:
135
  f.write(r.content)
136
  return wav_path
137
 
138
- # Case 2: JSON base64
139
  data = r.json()
140
  audio_b64 = data.get("audio") or data.get("audio_base64") or data.get("wav")
141
  if not audio_b64:
@@ -157,21 +158,23 @@ def run_pipeline(question: str, language_id: str):
157
  if not question.strip():
158
  return "", None
159
 
160
- answer = answer_question(question)
 
161
 
 
162
  try:
163
  audio_path = generate_audio(answer, language_id)
164
  except Exception as e:
165
- print("TTS generation failed:", e)
166
  audio_path = None
167
 
168
  return f"**Bot:** {answer}", audio_path
169
 
170
  # =====================================================
171
- # UI
172
  # =====================================================
173
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
174
- gr.Markdown("# 📄 Qwen CPU Assistant + TTS")
175
 
176
  with gr.Row():
177
  with gr.Column(scale=1):
@@ -197,5 +200,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
197
  outputs=[answer_text, answer_audio]
198
  )
199
 
200
- demo.queue() # long-running jobs OK (up to 5 min audio)
201
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
13
# CONFIG
# =====================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DOC_FILE_EN = "general.md"      # English knowledge-base document
DOC_FILE_HI = "general-hi.md"   # Hindi knowledge-base document
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
)
MAX_NEW_TOKENS = 128  # generation budget per answer
TOP_K = 3             # number of retrieved chunks per query

# Single pooled HTTP session reused for every TTS request.
SESSION = requests.Session()

# =====================================================
# LOAD DOCUMENTS
# =====================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH_EN = os.path.join(BASE_DIR, DOC_FILE_EN)
DOC_PATH_HI = os.path.join(BASE_DIR, DOC_FILE_HI)


def _read_doc(path: str, name: str) -> str:
    """Read one knowledge-base document, failing fast if it is missing.

    Raises:
        RuntimeError: if *path* does not exist (message names the file).
    """
    if not os.path.exists(path):
        raise RuntimeError(f"{name} not found")
    # errors="ignore" silently drops undecodable bytes — assumed acceptable
    # for these markdown knowledge files (TODO confirm).
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


DOC_TEXT_EN = _read_doc(DOC_PATH_EN, DOC_FILE_EN)
DOC_TEXT_HI = _read_doc(DOC_PATH_HI, DOC_FILE_HI)
42
 
43
  # =====================================================
44
  # CHUNK + EMBED
 
51
  i += chunk_size - overlap
52
  return chunks
53
 
54
# Pre-chunk both corpora once at startup so retrieval is a pure lookup.
DOC_CHUNKS_EN = chunk_text(DOC_TEXT_EN)
DOC_CHUNKS_HI = chunk_text(DOC_TEXT_HI)

# CPU-only sentence embedder, shared by startup indexing and per-query
# encoding in retrieve_context().
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
# normalize_embeddings=True means the dot product used at query time is
# cosine similarity.
DOC_EMBEDS_EN = embedder.encode(DOC_CHUNKS_EN, normalize_embeddings=True, batch_size=32)
DOC_EMBEDS_HI = embedder.encode(DOC_CHUNKS_HI, normalize_embeddings=True, batch_size=32)
60
 
61
  # =====================================================
62
  # LOAD QWEN MODEL (CPU only)
 
64
# Tokenizer + causal LM used for answer generation. trust_remote_code is
# required because Qwen ships custom model/tokenizer classes on the Hub.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",            # strictly CPU deployment
    torch_dtype=torch.float32,   # full precision on CPU
    trust_remote_code=True
)
 
74
  # RETRIEVAL WITH CACHE
75
  # =====================================================
76
@lru_cache(maxsize=256)
def retrieve_context(question: str, lang: str):
    """Return the TOP_K most similar document chunks for *question*.

    *lang* selects the corpus: ``"hi"`` searches the Hindi document,
    anything else falls back to English. Results are memoized per
    (question, lang) pair via lru_cache.
    """
    # Pick the corpus once instead of duplicating the whole scoring
    # pipeline in each branch (the original repeated it verbatim).
    if lang == "hi":
        chunks, embeds = DOC_CHUNKS_HI, DOC_EMBEDS_HI
    else:
        chunks, embeds = DOC_CHUNKS_EN, DOC_EMBEDS_EN
    q_emb = embedder.encode([question], normalize_embeddings=True)
    # Embeddings are normalized, so the dot product is cosine similarity.
    scores = np.dot(embeds, q_emb[0])
    # Indices of the TOP_K highest scores, best first.
    top_ids = scores.argsort()[-TOP_K:][::-1]
    return "\n\n".join(chunks[i] for i in top_ids)
87
 
88
  # =====================================================
89
  # QWEN ANSWER (CPU optimized)
90
  # =====================================================
91
+ def answer_question(question: str, lang: str = "en") -> str:
92
+ context = retrieve_context(question, lang)
93
 
94
  messages = [
95
  {
 
102
  "'I could not find this information in the document.'"
103
  )
104
  },
105
+ {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
 
 
 
106
  ]
107
 
108
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
109
  inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
110
 
111
  with torch.no_grad():
 
124
  # =====================================================
125
  @lru_cache(maxsize=128)
126
  def generate_audio(text: str, language_id: str = "en") -> str:
127
+ payload = {"text": text, "language_id": language_id, "mode": "Speak 🗣️"}
 
 
 
 
 
128
  r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
129
  r.raise_for_status()
130
 
131
  wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
132
 
133
+ # raw audio bytes
134
  if r.headers.get("content-type", "").startswith("audio"):
135
  with open(wav_path, "wb") as f:
136
  f.write(r.content)
137
  return wav_path
138
 
139
+ # JSON base64
140
  data = r.json()
141
  audio_b64 = data.get("audio") or data.get("audio_base64") or data.get("wav")
142
  if not audio_b64:
 
158
  if not question.strip():
159
  return "", None
160
 
161
+ # 1️⃣ Answer text
162
+ answer = answer_question(question, language_id)
163
 
164
+ # 2️⃣ TTS
165
  try:
166
  audio_path = generate_audio(answer, language_id)
167
  except Exception as e:
168
+ print("TTS failed:", e)
169
  audio_path = None
170
 
171
  return f"**Bot:** {answer}", audio_path
172
 
173
  # =====================================================
174
+ # GRADIO UI
175
  # =====================================================
176
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
177
+ gr.Markdown("# 📄 Qwen CPU Assistant + Multilingual TTS")
178
 
179
  with gr.Row():
180
  with gr.Column(scale=1):
 
200
  outputs=[answer_text, answer_audio]
201
  )
202
 
203
+ demo.queue() # enable long-running jobs for TTS
204
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)