rahul7star commited on
Commit
bb5fd6a
·
verified ·
1 Parent(s): 5deb4cf

Create app_qwen_tts_fast.py

Browse files
Files changed (1) hide show
  1. app_qwen_tts_fast.py +195 -0
app_qwen_tts_fast.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import torch
4
+ import gradio as gr
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from functools import lru_cache
8
+
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from sentence_transformers import SentenceTransformer
11
+
12
# =========================================================
# CONFIG
# =========================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # chat LLM used to answer questions
DOC_FILE = "general.md"  # knowledge-base document, resolved next to this file

MAX_NEW_TOKENS = 200  # generation budget per answer
TOP_K = 3  # number of document chunks retrieved per question
MAX_TTS_CHARS = 200 # 🔥 BIG SPEED WIN

# Remote TTS endpoint; overridable via the TTS_API_URL environment variable.
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
)

SESSION = requests.Session() # 🔥 reuse HTTP connection

# Resolve the document path relative to this source file, not the CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
31
+
32
# =========================================================
# LOAD MODELS
# =========================================================
# Tokenizer and causal LM are loaded once at import time.
# trust_remote_code is required by some Qwen model revisions.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # let transformers place layers on GPU/CPU
    # fp16 halves memory on GPU; fp32 is the safe CPU fallback
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)
model.eval()  # inference only: disables dropout etc.

# Sentence embedder used for retrieval over the document chunks.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
46
+
47
+ # =========================================================
48
+ # LOAD DOCUMENT
49
+ # =========================================================
50
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into word-based chunks with overlapping boundaries.

    Consecutive chunks share ``overlap`` words so that sentences
    straddling a chunk boundary are not lost to retrieval.

    Args:
        text: Source document as a single string.
        chunk_size: Number of words per chunk; must be positive.
        overlap: Words shared between consecutive chunks; must be
            strictly smaller than ``chunk_size``.

    Returns:
        list[str]: Whitespace-joined word chunks ([] for empty text).

    Raises:
        ValueError: If ``chunk_size <= 0`` or ``overlap >= chunk_size``.
            (The original ``while`` loop would spin forever in that case
            because the step ``chunk_size - overlap`` is not positive.)
    """
    step = chunk_size - overlap
    if chunk_size <= 0 or step <= 0:
        raise ValueError("chunk_size must be positive and greater than overlap")
    words = text.split()
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
    return chunks
58
+
59
# Read the knowledge base once at startup; errors="ignore" drops any
# undecodable bytes rather than crashing on a malformed document.
with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT = f.read()

# Pre-chunk and pre-embed the whole document so each query only has to
# embed the question itself. Normalized embeddings make a plain dot
# product equal to cosine similarity at retrieval time.
DOC_CHUNKS = chunk_text(DOC_TEXT)
DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True)
64
+
65
+ # =========================================================
66
+ # RETRIEVAL
67
+ # =========================================================
68
def retrieve_context(question, k=TOP_K):
    """Return the *k* document chunks most similar to *question*.

    The question is embedded with the shared sentence-transformer and
    each pre-computed chunk embedding is scored by dot product (both
    sides are normalized, so this is cosine similarity). The top-k
    chunks, best first, are joined into a single context string.
    """
    query_vec = embedder.encode([question], normalize_embeddings=True)[0]
    similarities = DOC_EMBEDS @ query_vec
    best = np.argsort(similarities)[::-1][:k]
    return "\n\n".join(DOC_CHUNKS[idx] for idx in best)
73
+
74
+ # =========================================================
75
+ # QWEN (CACHED)
76
+ # =========================================================
77
@lru_cache(maxsize=128)
def cached_answer(question: str) -> str:
    """Generate a short, document-grounded answer to *question*.

    Retrieves the top-k context chunks, builds a chat prompt, and
    samples a completion from the Qwen model. Answers are memoized
    (LRU, 128 entries), so repeated questions skip generation.

    Returns:
        The generated answer text, whitespace-stripped.
    """
    context = retrieve_context(question)

    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Do NOT repeat context or question.\n"
                "Respond in 1 short sentence.\n"
                "If not found say:\n"
                "'I could not find this information in the document.'"
            )
        },
        {
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion:\n{question}"
        }
    ]

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            do_sample=True
        )

    # BUG FIX: the original decoded the FULL sequence (prompt included)
    # and kept only the last line, which truncated multi-line answers and
    # could leak prompt text when the answer contained no newline.
    # Decode only the newly generated tokens instead.
    prompt_len = inputs["input_ids"].shape[1]
    generated = output[0][prompt_len:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
115
+
116
+ # =========================================================
117
+ # TTS (CACHED)
118
+ # =========================================================
119
@lru_cache(maxsize=128)
def cached_tts(text: str) -> str:
    """Synthesize *text* to speech via the remote TTS service.

    Truncates the input to MAX_TTS_CHARS, POSTs it to TTS_API_URL over
    the shared HTTP session, and writes the returned WAV bytes to /tmp.
    Results are memoized (LRU, 128 entries) keyed on the full text.

    Returns:
        Path of the written WAV file.

    Raises:
        requests.HTTPError: If the TTS service returns an error status.
        requests.Timeout: If the service does not respond in time.
    """
    import hashlib  # stdlib; local import keeps the module header unchanged

    payload = {
        "text": text[:MAX_TTS_CHARS],
        "language_id": "en",
        "mode": "Speak 🗣️",
        "exaggeration": 0.5,
        "temperature": 0.8,
        "cfg_weight": 0.5
    }

    # BUG FIX: added a timeout — without one, a stalled TTS service would
    # block this request (and the whole pipeline) forever.
    r = SESSION.post(TTS_API_URL, json=payload, timeout=120)
    r.raise_for_status()

    # BUG FIX: hash(text) is randomized per process (PYTHONHASHSEED), so
    # filenames were unstable across restarts; a content digest gives the
    # same path for the same text every run.
    digest = hashlib.sha1(text.encode("utf-8")).hexdigest()
    audio_path = f"/tmp/{digest}.wav"

    with open(audio_path, "wb") as f:
        f.write(r.content)  # raw WAV bytes from the service

    return audio_path
140
+
141
+ # =========================================================
142
+ # PIPELINE
143
+ # =========================================================
144
def run_pipeline(question):
    """Answer *question* and synthesize the answer as speech.

    Returns a ``(answer_text, audio_path)`` pair. Blank or whitespace-only
    input short-circuits to ``("", None)`` without touching the model or
    the TTS service.
    """
    if not question.strip():
        return "", None

    # Text generation first (fast, LRU-cached) ...
    reply = cached_answer(question)

    # ... then voice synthesis (slower remote call, also cached).
    return reply, cached_tts(reply)
155
+
156
+ # =========================================================
157
+ # UI
158
+ # =========================================================
159
def build_ui():
    """Assemble and launch the Gradio front-end.

    Wires a single question textbox through ``run_pipeline`` to a
    Markdown answer pane and an autoplaying audio player, then serves
    the app on 0.0.0.0:7860 (blocking call).
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🤖 OhamLab AI Assistant with Voice")

        with gr.Row():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask something about the document...",
                lines=2
            )

        ask_button = gr.Button("🚀 Ask")

        with gr.Row():
            answer_output = gr.Markdown(label="Answer")
        with gr.Row():
            audio_output = gr.Audio(label="Voice Response", autoplay=True)

        ask_button.click(
            fn=run_pipeline,
            inputs=question_input,
            outputs=[answer_output, audio_output]
        )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False
    )
189
+
190
+ # =========================================================
191
+ # MAIN
192
+ # =========================================================
193
# Script entry point: announce readiness, then start the (blocking)
# Gradio server. Importing this module elsewhere skips the launch.
if __name__ == "__main__":
    print("✅ Qwen + TTS Assistant Ready")
    build_ui()