rahul7star committed on
Commit
a6276b5
·
verified ·
1 Parent(s): ff89053

Update app_qwen_tts.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts.py +101 -107
app_qwen_tts.py CHANGED
@@ -1,56 +1,51 @@
1
  import os
2
  import io
3
  import base64
4
- import time
5
  import torch
6
  import gradio as gr
7
  import numpy as np
8
- import soundfile as sf
9
- import requests
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from sentence_transformers import SentenceTransformer
 
12
 
13
- # =======================
14
- # Configuration
15
- # =======================
16
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
17
  DOC_FILE = "general.md"
 
18
  MAX_NEW_TOKENS = 200
19
  TOP_K = 3
20
- TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts" # your FastAPI TTS endpoint
21
 
22
- # =======================
23
- # Load document
24
- # =======================
25
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
26
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
27
 
28
- if not os.path.exists(DOC_PATH):
29
- raise RuntimeError(f"{DOC_FILE} not found next to app.py")
30
-
31
  with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
32
  DOC_TEXT = f.read()
33
 
34
- # =======================
35
- # Chunk document
36
- # =======================
37
  def chunk_text(text, chunk_size=300, overlap=50):
38
  words = text.split()
39
- chunks = []
40
- i = 0
41
  while i < len(words):
42
- chunk = words[i:i + chunk_size]
43
- chunks.append(" ".join(chunk))
44
  i += chunk_size - overlap
45
  return chunks
46
 
47
  DOC_CHUNKS = chunk_text(DOC_TEXT)
 
 
48
 
49
- # =======================
50
- # Load models
51
- # =======================
52
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
53
-
54
  model = AutoModelForCausalLM.from_pretrained(
55
  MODEL_ID,
56
  device_map="auto",
@@ -59,108 +54,107 @@ model = AutoModelForCausalLM.from_pretrained(
59
  )
60
  model.eval()
61
 
62
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
63
- DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True, show_progress_bar=True)
64
-
65
- # =======================
66
- # Utilities
67
- # =======================
68
- def retrieve_context(question, k=TOP_K):
69
  q_emb = embedder.encode([question], normalize_embeddings=True)
70
  scores = np.dot(DOC_EMBEDS, q_emb[0])
71
- top_ids = scores.argsort()[-k:][::-1]
72
- return "\n\n".join([DOC_CHUNKS[i] for i in top_ids])
73
-
74
- def extract_final_answer(text: str) -> str:
75
- text = text.strip()
76
- markers = ["assistant:", "assistant", "answer:", "final answer:"]
77
- for m in markers:
78
- if m.lower() in text.lower():
79
- text = text.lower().split(m, 1)[-1].strip()
80
- lines = [l.strip() for l in text.split("\n") if l.strip()]
81
- return lines[-1] if lines else text
82
-
83
- # =======================
84
- # Qwen inference
85
- # =======================
86
- def answer_question(question: str) -> str:
87
  context = retrieve_context(question)
 
88
  messages = [
89
  {
90
  "role": "system",
91
  "content": (
92
  "You are a strict document-based Q&A assistant.\n"
93
  "Answer ONLY the question.\n"
94
- "Do NOT repeat the context or the question.\n"
95
- "Respond in 1–2 sentences.\n"
96
- "If the answer is not present, say:\n"
97
  "'I could not find this information in the document.'"
98
  )
99
  },
100
  {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
101
  ]
102
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
103
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
104
 
105
  with torch.no_grad():
106
- output = model.generate(
107
- **inputs,
108
- max_new_tokens=MAX_NEW_TOKENS,
109
- temperature=0.3,
110
- do_sample=True
111
- )
112
 
113
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
114
- return extract_final_answer(decoded)
115
-
116
- # =======================
117
- # TTS via FastAPI
118
- # =======================
119
- def generate_tts_base64(text: str, language_id="en") -> str:
120
- try:
121
- payload = {"text": text, "language_id": language_id, "mode": "Speak 🗣️"}
122
- resp = requests.post(TTS_API_URL, json=payload, timeout=None) # no timeout
123
- resp.raise_for_status()
124
- audio_b64 = resp.json().get("audio", "")
125
- return audio_b64
126
- except Exception as e:
127
- print(f"TTS error: {e}")
128
- return None
129
-
130
- # =======================
131
- # Chat function for Gradio
132
- # =======================
133
- def chat(user_message, history):
134
- if not user_message.strip():
135
- return "", history
136
-
137
- # 1️⃣ Text answer immediately
138
- answer_text = answer_question(user_message)
139
- history.append((user_message, [answer_text, None])) # audio placeholder
140
-
141
- # 2️⃣ Generate audio asynchronously
142
- audio_b64 = generate_tts_base64(answer_text)
143
- if audio_b64:
144
- history[-1][1][1] = f"data:audio/wav;base64,{audio_b64}"
145
-
146
- return "", history
147
-
148
- def reset_chat():
149
- return []
150
-
151
- # =======================
152
- # Build UI
153
- # =======================
154
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
155
- gr.Markdown("## 📄 Qwen Document Assistant + TTS\nText appears instantly; audio plays once ready.")
156
-
157
- chatbot = gr.Chatbot(height=450, type="tuples")
158
- msg = gr.Textbox(placeholder="Ask a question...", lines=2)
159
- send = gr.Button("Send")
160
- clear = gr.Button("Clear")
161
-
162
- send.click(chat, [msg, chatbot], [msg, chatbot])
163
- msg.submit(chat, [msg, chatbot], [msg, chatbot])
164
- clear.click(reset_chat, outputs=chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
1
  import os
2
  import io
3
  import base64
4
+ import requests
5
  import torch
6
  import gradio as gr
7
  import numpy as np
 
 
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  from sentence_transformers import SentenceTransformer
10
+ from scipy.io.wavfile import write as write_wav
11
 
12
# =====================================================
# CONFIG
# =====================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # chat LLM used to answer questions
DOC_FILE = "general.md"                  # knowledge document, expected next to this script
TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"  # external TTS endpoint
MAX_NEW_TOKENS = 200                     # generation budget per answer
TOP_K = 3                                # number of retrieved context chunks
 
20
 
21
# =====================================================
# LOAD DOCUMENT
# =====================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)

# Fail fast with a clear message when the document is missing, instead of a
# bare FileNotFoundError from open() below.
if not os.path.exists(DOC_PATH):
    raise RuntimeError(f"{DOC_FILE} not found next to this script")

# errors="ignore" drops undecodable bytes rather than aborting startup.
with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT = f.read()
 
30
+ # =====================================================
31
+ # CHUNK + EMBED
32
+ # =====================================================
33
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping chunks of whitespace-separated words.

    Each chunk holds up to ``chunk_size`` words; consecutive chunks share
    ``overlap`` words so retrieval does not miss answers that straddle a
    chunk boundary.

    Args:
        text: the document to split.
        chunk_size: maximum words per chunk (must be positive).
        overlap: words shared between consecutive chunks (must be smaller
            than ``chunk_size``).

    Returns:
        List of chunk strings; empty list for empty/whitespace-only text.

    Raises:
        ValueError: if the parameters would make the loop step non-positive
            (the original while-loop would spin forever in that case).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    step = chunk_size - overlap
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
40
 
41
# Chunk the document once at start-up, then embed every chunk so that
# retrieval is a single dot product at question time. Normalized embeddings
# make dot product equal cosine similarity.
DOC_CHUNKS = chunk_text(DOC_TEXT)
embedder = SentenceTransformer("all-MiniLM-L6-v2")
DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True)
44
 
45
+ # =====================================================
46
+ # LOAD QWEN
47
+ # =====================================================
48
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
  MODEL_ID,
51
  device_map="auto",
 
54
  )
55
  model.eval()
56
 
57
+ # =====================================================
58
+ # RETRIEVAL
59
+ # =====================================================
60
def retrieve_context(question, k=TOP_K):
    """Return the *k* document chunks most similar to *question*.

    Both the query and the chunk embeddings are L2-normalized, so the dot
    product below is cosine similarity. Best match comes first; chunks are
    joined with blank lines into a single context string.

    Args:
        question: free-text user question.
        k: number of chunks to retrieve (defaults to ``TOP_K``; restores the
            configurable parameter the earlier revision exposed).
    """
    q_emb = embedder.encode([question], normalize_embeddings=True)
    scores = np.dot(DOC_EMBEDS, q_emb[0])
    top_ids = scores.argsort()[-k:][::-1]  # indices of the k highest scores, descending
    return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
65
+
66
+ # =====================================================
67
+ # QWEN ANSWER
68
+ # =====================================================
69
def answer_question(question):
    """Answer *question* strictly from the retrieved document context.

    Builds a chat prompt with a strict system instruction, greedily decodes
    up to ``MAX_NEW_TOKENS`` tokens with Qwen, and returns the last
    non-empty line of the decoded text (where the chat template places the
    assistant's reply).
    """
    context = retrieve_context(question)

    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Do NOT repeat context.\n"
                "Respond in 1 sentence.\n"
                "If not found, say:\n"
                "'I could not find this information in the document.'"
            )
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
    ]

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)

    # Fix: splitting on "\n" and taking [-1] can yield an empty string when
    # the decode ends with a blank line. Return the last line that actually
    # has content, falling back to the whole (stripped) output.
    lines = [line.strip() for line in decoded.splitlines() if line.strip()]
    return lines[-1] if lines else decoded.strip()
97
+
98
+ # =====================================================
99
+ # TTS (BASE64 → WAV)
100
+ # =====================================================
101
def generate_audio(text):
    """Synthesize *text* to speech via the external TTS API.

    Posts the text to ``TTS_API_URL`` (deliberately no timeout — the
    endpoint can be slow), decodes the base64 WAV payload, and writes it to
    a unique temporary file.

    Returns:
        Path to a WAV file, or ``None`` when the request, JSON parsing, or
        base64 decoding fails — so the UI can still show the text answer.
    """
    import tempfile  # function-scope: keeps the module import block untouched

    payload = {"text": text, "language_id": "en", "mode": "Speak 🗣️"}
    try:
        r = requests.post(TTS_API_URL, json=payload, timeout=None)
        r.raise_for_status()  # fix: surface HTTP errors instead of a KeyError below
        audio_bytes = base64.b64decode(r.json()["audio"])
    except (requests.RequestException, KeyError, ValueError) as e:
        # ValueError covers both invalid JSON bodies and bad base64
        # (binascii.Error subclasses ValueError).
        print(f"TTS error: {e}")
        return None

    # Fix: the previous fixed "/tmp/output.wav" path was clobbered when two
    # requests ran concurrently; each call now gets its own file.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        return f.name
113
+
114
+ # =====================================================
115
+ # MAIN HANDLER
116
+ # =====================================================
117
def run_pipeline(question):
    """Gradio handler: produce the text answer, then its spoken version.

    Returns:
        ``(answer_text, wav_path_or_None)`` matching the
        ``[answer_text, answer_audio]`` outputs of the UI.
    """
    # Guard both empty strings and the None a cleared textbox may deliver.
    if not question or not question.strip():
        return "", None

    # 1️⃣ Text answer first — this is the fast part.
    answer = answer_question(question)

    # 2️⃣ Audio second (slow, no timeout); a TTS failure must never hide
    # the text answer, so degrade to text-only on any error.
    try:
        audio_path = generate_audio(answer)
    except Exception as e:
        print(f"TTS error: {e}")
        audio_path = None

    return answer, audio_path
128
+
129
+ # =====================================================
130
+ # UI
131
+ # =====================================================
 
 
 
 
132
# Two-column layout: question entry on the left, answer text + voice on the
# right. The single Ask button drives the full text→speech pipeline.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 Document Q&A with Voice")

    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                label="Your Question",
                placeholder="Who is CEO of OhamLab?",
                lines=4,
            )
            ask_btn = gr.Button("Ask")

        with gr.Column(scale=1):
            answer_text = gr.Markdown(
                label="Assistant Answer",
                value="**Bot:** _Waiting for question..._",
            )
            answer_audio = gr.Audio(
                label="Assistant Voice",
                type="filepath",  # run_pipeline returns a WAV file path
            )

    ask_btn.click(
        fn=run_pipeline,
        inputs=user_input,
        outputs=[answer_text, answer_audio],
    )

demo.launch(server_name="0.0.0.0", server_port=7860, share=False)