rahul7star commited on
Commit
fbf4f7f
·
verified ·
1 Parent(s): 43b2aa5

Update app_qwen_tts_fast.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts_fast.py +107 -109
app_qwen_tts_fast.py CHANGED
@@ -1,81 +1,82 @@
1
  import os
 
 
2
  import requests
3
  import torch
4
  import gradio as gr
5
  import numpy as np
6
- import soundfile as sf
7
  from functools import lru_cache
8
-
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from sentence_transformers import SentenceTransformer
11
 
12
- # =========================================================
13
  # CONFIG
14
- # =========================================================
15
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
  DOC_FILE = "general.md"
17
-
18
- MAX_NEW_TOKENS = 200
19
- TOP_K = 3
20
- MAX_TTS_CHARS = 200 # 🔥 BIG SPEED WIN
21
-
22
  TTS_API_URL = os.getenv(
23
  "TTS_API_URL",
24
  "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
25
  )
26
 
27
- SESSION = requests.Session() # 🔥 reuse HTTP connection
 
 
 
28
 
 
 
 
29
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
30
  DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
31
 
32
- # =========================================================
33
- # LOAD MODELS
34
- # =========================================================
35
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
36
-
37
- model = AutoModelForCausalLM.from_pretrained(
38
- MODEL_ID,
39
- device_map="auto",
40
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
41
- trust_remote_code=True
42
- )
43
- model.eval()
44
-
45
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
46
 
47
- # =========================================================
48
- # LOAD DOCUMENT
49
- # =========================================================
50
  def chunk_text(text, chunk_size=300, overlap=50):
51
  words = text.split()
52
- chunks = []
53
- i = 0
54
  while i < len(words):
55
  chunks.append(" ".join(words[i:i + chunk_size]))
56
  i += chunk_size - overlap
57
  return chunks
58
 
59
- with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
60
- DOC_TEXT = f.read()
61
-
62
  DOC_CHUNKS = chunk_text(DOC_TEXT)
63
- DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True)
64
 
65
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # RETRIEVAL
67
- # =========================================================
68
- def retrieve_context(question, k=TOP_K):
 
69
  q_emb = embedder.encode([question], normalize_embeddings=True)
70
  scores = np.dot(DOC_EMBEDS, q_emb[0])
71
- ids = scores.argsort()[-k:][::-1]
72
- return "\n\n".join(DOC_CHUNKS[i] for i in ids)
73
 
74
- # =========================================================
75
- # QWEN (CACHED)
76
- # =========================================================
77
- @lru_cache(maxsize=128)
78
- def cached_answer(question: str) -> str:
79
  context = retrieve_context(question)
80
 
81
  messages = [
@@ -84,9 +85,8 @@ def cached_answer(question: str) -> str:
84
  "content": (
85
  "You are a strict document-based Q&A assistant.\n"
86
  "Answer ONLY the question.\n"
87
- "Do NOT repeat context or question.\n"
88
  "Respond in 1 short sentence.\n"
89
- "If not found say:\n"
90
  "'I could not find this information in the document.'"
91
  )
92
  },
@@ -106,99 +106,97 @@ def cached_answer(question: str) -> str:
106
  output = model.generate(
107
  **inputs,
108
  max_new_tokens=MAX_NEW_TOKENS,
109
- temperature=0.3,
110
- do_sample=True
111
  )
112
 
113
- text = tokenizer.decode(output[0], skip_special_tokens=True)
114
- return text.strip().split("\n")[-1]
115
-
116
- # =========================================================
117
- # TTS (CACHED)
118
- # =========================================================
119
- import base64
120
 
 
 
 
121
  @lru_cache(maxsize=128)
122
- def cached_tts(text: str) -> str:
123
  payload = {
124
- "text": text[:MAX_TTS_CHARS],
125
  "language_id": "en",
126
- "mode": "Speak 🗣️",
127
- "exaggeration": 0.5,
128
- "temperature": 0.8,
129
- "cfg_weight": 0.5
130
  }
131
 
132
- r = SESSION.post(TTS_API_URL, json=payload)
133
  r.raise_for_status()
134
 
 
 
 
 
 
 
 
 
 
 
135
  data = r.json()
 
 
 
 
 
136
 
137
- if "audio" not in data:
138
- raise RuntimeError("TTS API returned no audio field")
139
 
140
- audio_b64 = data["audio"]
141
  audio_bytes = base64.b64decode(audio_b64)
142
 
143
- audio_path = f"/tmp/tts_{abs(hash(text))}.wav"
144
- with open(audio_path, "wb") as f:
145
  f.write(audio_bytes)
146
 
147
- return audio_path
 
148
 
 
149
 
150
- # =========================================================
151
- # PIPELINE
152
- # =========================================================
153
  def run_pipeline(question):
154
  if not question.strip():
155
  return "", None
156
 
157
- # 1️⃣ TEXT (FAST)
158
- answer = cached_answer(question)
159
-
160
- # 2️⃣ AUDIO (CAN TAKE TIME)
161
- audio_path = cached_tts(answer)
162
 
163
- return answer, audio_path
164
 
165
- # =========================================================
166
  # UI
167
- # =========================================================
168
- def build_ui():
169
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
- gr.Markdown("## 🤖 OhamLab AI Assistant with Voice")
171
 
172
- with gr.Row():
173
- question = gr.Textbox(
 
174
  label="Your Question",
175
- placeholder="Ask something about the document...",
176
- lines=2
177
  )
 
178
 
179
- ask = gr.Button("🚀 Ask")
 
 
180
 
181
- with gr.Row():
182
- answer_box = gr.Markdown(label="Answer")
183
- with gr.Row():
184
- audio_box = gr.Audio(label="Voice Response", autoplay=True)
185
-
186
- ask.click(
187
- fn=run_pipeline,
188
- inputs=question,
189
- outputs=[answer_box, audio_box]
190
- )
191
-
192
- demo.launch(
193
- server_name="0.0.0.0",
194
- server_port=7860,
195
- share=False,
196
- show_api=False
197
- )
198
 
199
- # =========================================================
200
- # MAIN
201
- # =========================================================
202
- if __name__ == "__main__":
203
- print("✅ Qwen + TTS Assistant Ready")
204
- build_ui()
 
1
  import os
2
+ import base64
3
+ import uuid
4
  import requests
5
  import torch
6
  import gradio as gr
7
  import numpy as np
 
8
  from functools import lru_cache
 
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from sentence_transformers import SentenceTransformer
11
 
12
# =====================================================
# CONFIG
# =====================================================
# Hugging Face model id for the chat LLM (small, CPU-friendly).
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
# Markdown document used as the retrieval corpus; must sit next to this file.
DOC_FILE = "general.md"

# Remote text-to-speech endpoint; override via the TTS_API_URL env var.
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
)

MAX_NEW_TOKENS = 128  # cap on generated answer length
TOP_K = 3             # number of document chunks retrieved per question

# Shared HTTP session so repeated TTS calls reuse one connection pool.
SESSION = requests.Session()

# =====================================================
# LOAD DOCUMENT
# =====================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)

# errors="ignore" drops undecodable bytes instead of failing at startup.
with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT = f.read()
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # =====================================================
37
+ # CHUNK + EMBED
38
+ # =====================================================
39
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping word-window chunks.

    Args:
        text: Raw document text; split on whitespace.
        chunk_size: Maximum number of words per chunk.
        overlap: Words shared between consecutive chunks; must be
            smaller than ``chunk_size`` so the window always advances.

    Returns:
        list[str]: Chunks of at most ``chunk_size`` words each
        (empty list for empty/whitespace-only text).

    Raises:
        ValueError: If ``chunk_size`` is non-positive or ``overlap``
            is not smaller than ``chunk_size`` (the original while-loop
            would spin forever in that case).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    step = chunk_size - overlap
    # A positive range() step guarantees termination structurally.
    return [
        " ".join(words[i:i + chunk_size])
        for i in range(0, len(words), step)
    ]
46
 
 
 
 
47
DOC_CHUNKS = chunk_text(DOC_TEXT)

# Sentence embedder for retrieval. Vectors are L2-normalized so that a
# plain dot product at query time equals cosine similarity.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
DOC_EMBEDS = embedder.encode(
    DOC_CHUNKS, normalize_embeddings=True, batch_size=32
)

# =====================================================
# LOAD QWEN (FAST SETTINGS)
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# fp16 on GPU halves memory / speeds generation; fp32 fallback on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)
model.eval()  # inference mode: disables dropout etc.
65
+
66
+ # =====================================================
67
  # RETRIEVAL
68
+ # =====================================================
69
@lru_cache(maxsize=256)
def retrieve_context(question: str):
    """Return the TOP_K most relevant document chunks, joined by blank lines.

    Similarity is a dot product against the precomputed ``DOC_EMBEDS``;
    because both sides are L2-normalized this is cosine similarity.
    Results are memoized per exact question string.
    """
    query_vec = embedder.encode([question], normalize_embeddings=True)[0]
    similarity = np.dot(DOC_EMBEDS, query_vec)
    # Full descending sort, then take the first TOP_K indices.
    best = np.argsort(similarity)[::-1][:TOP_K]
    return "\n\n".join(DOC_CHUNKS[idx] for idx in best)
75
 
76
+ # =====================================================
77
+ # QWEN ANSWER (FAST)
78
+ # =====================================================
79
+ def answer_question(question: str) -> str:
 
80
  context = retrieve_context(question)
81
 
82
  messages = [
 
85
  "content": (
86
  "You are a strict document-based Q&A assistant.\n"
87
  "Answer ONLY the question.\n"
 
88
  "Respond in 1 short sentence.\n"
89
+ "If not found, say:\n"
90
  "'I could not find this information in the document.'"
91
  )
92
  },
 
106
  output = model.generate(
107
  **inputs,
108
  max_new_tokens=MAX_NEW_TOKENS,
109
+ do_sample=False,
110
+ use_cache=True
111
  )
112
 
113
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
114
+ return decoded.split("\n")[-1].strip()
 
 
 
 
 
115
 
116
+ # =====================================================
117
+ # TTS (FAST + SAFE)
118
+ # =====================================================
119
@lru_cache(maxsize=128)
def generate_audio(text: str) -> str:
    """Synthesize *text* via the remote TTS API and return a WAV file path.

    The API replies either with raw audio bytes (content-type ``audio/*``)
    or with a JSON body carrying base64 audio under one of several keys.
    Results are memoized per text; NOTE(review): the cached /tmp path can
    go stale if the file is cleaned up while the process keeps running.

    Raises:
        requests.HTTPError: On a non-2xx response.
        RuntimeError: If no audio payload is found, or the written file
            is implausibly small.
    """
    payload = {
        "text": text,
        "language_id": "en",
        "mode": "Speak 🗣️"
    }

    # timeout=None blocks indefinitely; tolerated here because TTS for
    # long answers can legitimately take a while.
    r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
    r.raise_for_status()

    # Unique output path so concurrent calls never clobber each other.
    wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"

    if r.headers.get("content-type", "").startswith("audio"):
        # Case 1: raw audio bytes in the response body.
        audio_bytes = r.content
    else:
        # Case 2: JSON with base64-encoded audio under a known key.
        data = r.json()
        audio_b64 = (
            data.get("audio")
            or data.get("audio_base64")
            or data.get("wav")
        )
        if not audio_b64:
            raise RuntimeError(f"TTS API returned no audio field: {data}")
        audio_bytes = base64.b64decode(audio_b64)

    with open(wav_path, "wb") as f:
        f.write(audio_bytes)

    # Sanity check now applied to BOTH branches (the original only
    # checked the base64 path): a real WAV answer should exceed ~1 KB.
    if os.path.getsize(wav_path) < 1000:
        raise RuntimeError("Generated audio file is too small")

    return wav_path
159
 
160
+ # =====================================================
161
+ # MAIN PIPELINE
162
+ # =====================================================
163
def run_pipeline(question):
    """Answer *question* from the document, then synthesize the answer.

    Args:
        question: User question. May be ``None`` or blank — Gradio can
            deliver ``None`` before any text is entered.

    Returns:
        tuple: ``(markdown_answer, wav_path)`` on success, or
        ``("", None)`` for empty input.
    """
    # Guard against None as well as whitespace-only input; the original
    # raised AttributeError on None via .strip().
    if not question or not question.strip():
        return "", None

    answer = answer_question(question)
    audio_path = generate_audio(answer)

    return f"**Bot:** {answer}", audio_path
171
 
172
+ # =====================================================
173
  # UI
174
+ # =====================================================
175
# Build the Gradio UI at module level so `demo` exists on import
# (presumably required by the Hugging Face Spaces runtime — TODO confirm).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 📘 Document Q&A with Voice")

    with gr.Row():
        # Left column: question entry and submit button.
        with gr.Column():
            user_input = gr.Textbox(
                label="Your Question",
                placeholder="Who is CEO of OhamLab?",
                lines=3
            )
            ask_btn = gr.Button("Ask")

        # Right column: markdown answer and synthesized speech playback.
        with gr.Column():
            answer_text = gr.Markdown()
            answer_audio = gr.Audio(type="filepath")

    # Wire the button to the full text + audio pipeline.
    ask_btn.click(
        fn=run_pipeline,
        inputs=user_input,
        outputs=[answer_text, answer_audio]
    )

# NOTE(review): recent Gradio versions expose queueing via
# demo.queue().launch(...) rather than a `queue=` kwarg on launch();
# verify against the installed gradio version.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    queue=True
)