"""Fast voice/text Q&A demo: Whisper speech-to-text + web search + Qwen2.5-0.5B, served via Gradio."""

import base64
import logging
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor, wait
from contextlib import suppress
from datetime import datetime
from html.parser import HTMLParser

import gradio as gr
import requests
import torch
from faster_whisper import WhisperModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Models are loaded once at import time and shared by every request.
logger.info("Loading Whisper model...")
whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8")

logger.info("Loading Qwen 2.5 0.5B-Instruct (FASTEST)...")
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # 0.5B chosen for speed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="cpu",
    low_cpu_mem_usage=True,
)
logger.info("All models loaded!")

TAVILY_API_KEY = os.getenv('TAVILY_API_KEY', '')
BRAVE_API_KEY = os.getenv('BRAVE_API_KEY', '')

# Per-request HTTP timeout for every search backend. requests applies a single
# number to connect and read separately, so one call can take somewhat longer.
HTTP_TIMEOUT_S = 2
# Wall-clock budget for the whole parallel search fan-out.
SEARCH_DEADLINE_S = 3.0
# Number of hits kept per engine.
MAX_RESULTS = 3
# Token budget for search results inside the prompt; keeps the full prompt
# close to the previous 1200-token cap without truncating the question.
MAX_SEARCH_CONTEXT_TOKENS = 900
SEARX_INSTANCES = ('https://searx.be/search', 'https://searx.work/search')

_SYSTEM_PROMPT = (
    "Today is {current_date}. You are a concise assistant.\n"
    "\n"
    "When answering:\n"
    "- If question asks about multiple things, list each with a one-line description\n"
    "- Use bullet points for multiple items\n"
    "- Keep total answer to 80-100 words\n"
    "- Answer ONLY from search results"
)

_USER_PROMPT = (
    "Search Results:\n"
    "{search_results}\n"
    "\n"
    "Question: {question}\n"
    "\n"
    "Answer (80-100 words, use bullets if multiple topics):"
)


def _format_results(items, body_key):
    """Render up to MAX_RESULTS hits as numbered ``[i] title`` / body entries.

    Args:
        items: list of result dicts from a search API.
        body_key: key holding the snippet text ('content' or 'description').

    Returns:
        The formatted context string, or None when there is nothing to show
        (so the caller treats the engine as having produced no result).
    """
    context = "".join(
        f"\n[{i}] {item.get('title', '')}\n{item.get(body_key, '')}\n"
        for i, item in enumerate(items[:MAX_RESULTS], 1)
    )
    return context or None


def search_tavily(query):
    """Query the Tavily search API; return formatted context, or None on any failure."""
    logger.info("[TAVILY] Starting...")
    if not TAVILY_API_KEY:
        return None
    try:
        response = requests.post(
            'https://api.tavily.com/search',
            json={'api_key': TAVILY_API_KEY, 'query': query, 'max_results': MAX_RESULTS},
            timeout=HTTP_TIMEOUT_S,
        )
        if response.status_code != 200:
            logger.warning("[TAVILY] HTTP %s", response.status_code)
            return None
        context = _format_results(response.json().get('results') or [], 'content')
    except Exception as exc:  # best-effort: one failing engine must not break the search
        logger.warning("[TAVILY] Failed: %s", exc)
        return None
    if context:
        logger.info("[TAVILY] ✓")
    return context


def search_brave(query):
    """Query the Brave web-search API; return formatted context, or None on any failure."""
    logger.info("[BRAVE] Starting...")
    if not BRAVE_API_KEY:
        return None
    try:
        response = requests.get(
            'https://api.search.brave.com/res/v1/web/search',
            params={'q': query, 'count': MAX_RESULTS},
            headers={'X-Subscription-Token': BRAVE_API_KEY},
            timeout=HTTP_TIMEOUT_S,
        )
        if response.status_code != 200:
            logger.warning("[BRAVE] HTTP %s", response.status_code)
            return None
        web = response.json().get('web') or {}
        context = _format_results(web.get('results') or [], 'description')
    except Exception as exc:  # best-effort: one failing engine must not break the search
        logger.warning("[BRAVE] Failed: %s", exc)
        return None
    if context:
        logger.info("[BRAVE] ✓")
    return context


def search_searx(query):
    """Try each public SearX instance in order; return the first non-empty context, else None."""
    logger.info("[SEARX] Starting...")
    for instance in SEARX_INSTANCES:
        try:
            response = requests.get(
                instance,
                params={'q': query, 'format': 'json', 'categories': 'general'},
                timeout=HTTP_TIMEOUT_S,
            )
            if response.status_code != 200:
                logger.warning("[SEARX] %s HTTP %s", instance, response.status_code)
                continue
            context = _format_results(response.json().get('results') or [], 'content')
        except Exception as exc:  # best-effort: fall through to the next instance
            logger.warning("[SEARX] %s failed: %s", instance, exc)
            continue
        # An instance that answered with zero hits should not stop us trying the next one.
        if context:
            logger.info("[SEARX] ✓")
            return context
    return None


class _DDGParser(HTMLParser):
    """Collect the link text of DuckDuckGo HTML result titles (``<a class="result__a">``)."""

    def __init__(self):
        super().__init__()
        self.results = []
        self._in_result = False
        self._parts = []

    def handle_starttag(self, tag, attrs):
        # Attribute values are None for valueless attributes, hence ``v or ""``.
        if tag == 'a' and any(k == 'class' and 'result__a' in (v or "").split() for k, v in attrs):
            self._in_result = True

    def handle_data(self, data):
        if self._in_result:
            text = data.strip()
            if text:
                self._parts.append(text)

    def handle_endtag(self, tag):
        if tag == 'a' and self._in_result:
            # Join fragments (e.g. text split by <b> tags) with spaces so words don't fuse.
            title = " ".join(self._parts)
            if title:
                self.results.append(title)
            self._parts = []
            self._in_result = False


def search_duckduckgo_html(query):
    """Scrape DuckDuckGo's HTML endpoint for result titles; return context or None."""
    logger.info("[DDG] Starting...")
    try:
        response = requests.get(
            'https://html.duckduckgo.com/html/',
            params={'q': query},
            headers={'User-Agent': 'Mozilla/5.0'},
            timeout=HTTP_TIMEOUT_S,
        )
        if response.status_code != 200:
            logger.warning("[DDG] HTTP %s", response.status_code)
            return None
        parser = _DDGParser()
        parser.feed(response.text)
    except Exception as exc:  # best-effort: one failing engine must not break the search
        logger.warning("[DDG] Failed: %s", exc)
        return None
    context = "".join(
        f"\n[{i}] {title}\n" for i, title in enumerate(parser.results[:MAX_RESULTS], 1)
    )
    if not context:
        return None
    logger.info("[DDG] ✓")
    return context


# Engines in priority order: the first one with a result wins.
SEARCH_ENGINES = (
    ("Tavily", search_tavily),
    ("Brave", search_brave),
    ("Searx", search_searx),
    ("DDG", search_duckduckgo_html),
)


def search_parallel(query):
    """Run all engines concurrently and pick the highest-priority non-empty result.

    Waits at most SEARCH_DEADLINE_S in total; engines still running after the
    deadline are ignored.

    Returns:
        (context, engine_name), or ("No search results available.", "None").
    """
    logger.info("[SEARCH] Parallel start")
    executor = ThreadPoolExecutor(max_workers=len(SEARCH_ENGINES))
    try:
        futures = {executor.submit(fn, query): name for name, fn in SEARCH_ENGINES}
        # One shared deadline for all engines (not a per-future timeout).
        done, pending = wait(futures, timeout=SEARCH_DEADLINE_S)
    finally:
        # Don't block on stragglers; their own HTTP timeouts will end those threads.
        executor.shutdown(wait=False)

    for future in pending:
        logger.info("[SEARCH] %s missed the %.1fs deadline", futures[future], SEARCH_DEADLINE_S)

    results = {}
    for future in done:
        engine = futures[future]
        try:
            result = future.result()
        except Exception as exc:  # engines catch their own errors; this is a safety net
            logger.warning("[SEARCH] %s raised: %s", engine, exc)
            continue
        if result:
            results[engine] = result

    for engine, _ in SEARCH_ENGINES:
        if engine in results:
            logger.info("[SEARCH] Using %s", engine)
            return results[engine], engine
    return "No search results available.", "None"


def _transcribe_file(path):
    """Transcribe the audio file at *path* with Whisper (English, beam_size=1)."""
    segments, _ = whisper_model.transcribe(path, language="en", beam_size=1)
    # faster-whisper yields segments lazily: decoding happens while this join runs,
    # so errors surface here and the file must still exist until it finishes.
    return " ".join(seg.text for seg in segments)


def transcribe_audio_base64(audio_base64):
    """API endpoint: transcribe base64-encoded audio.

    Accepts raw base64 or a ``data:...;base64,`` URI.

    Returns:
        {"text": transcription} on success, or {"error": message} on failure.
    """
    logger.info("[STT] Request")
    temp_path = None
    try:
        if isinstance(audio_base64, str) and audio_base64.startswith("data:"):
            audio_base64 = audio_base64.split(",", 1)[-1]
        audio_bytes = base64.b64decode(audio_base64)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio_bytes)
            temp_path = temp_audio.name
        transcription = _transcribe_file(temp_path)
        logger.info("[STT] ✓")
        return {"text": transcription.strip()}
    except Exception as e:
        logger.error("[STT] Error: %s", e)
        return {"error": str(e)}
    finally:
        # Always remove the temp file, including when transcription fails.
        if temp_path is not None:
            with suppress(OSError):
                os.unlink(temp_path)


def _truncate_to_tokens(text, max_tokens):
    """Return *text* cut to at most *max_tokens* tokens of the model's tokenizer."""
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(ids) <= max_tokens:
        return text
    return tokenizer.decode(ids[:max_tokens], skip_special_tokens=True)


def generate_answer(text_input):
    """Answer *text_input* with Qwen, grounded in live web-search results.

    Returns:
        The answer followed by a ``**Source:** <engine>`` footer, "No input
        provided" for blank input, or "Error: ..." on failure. Never raises.
    """
    logger.info(f"[AI] Q: {text_input}")
    try:
        if not text_input or not text_input.strip():
            return "No input provided"

        current_date = datetime.now().strftime("%B %d, %Y")

        search_start = time.perf_counter()
        search_results, search_engine = search_parallel(text_input)
        logger.info(f"[AI] Search: {time.perf_counter() - search_start:.2f}s")

        # Trim only the search context. Truncating the finished prompt would cut
        # from the right (the tokenizer default) and drop the question plus the
        # assistant-turn marker added by the chat template.
        search_results = _truncate_to_tokens(search_results, MAX_SEARCH_CONTEXT_TOKENS)

        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT.format(current_date=current_date)},
            {
                "role": "user",
                "content": _USER_PROMPT.format(search_results=search_results, question=text_input),
            },
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        gen_start = time.perf_counter()
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # Hard latency cap; ~100 tokens may cut an 80-100-word answer short.
                max_new_tokens=100,
                # Sampling knobs shape the output style; they do not change speed.
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
            )
        logger.info(f"[AI] Gen: {time.perf_counter() - gen_start:.2f}s")

        # Decode only the newly generated tokens, not the echoed prompt.
        answer = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
        ).strip()
        logger.info("[AI] ✓")
        return f"{answer}\n\n**Source:** {search_engine}"
    except Exception as e:
        logger.exception("[AI] Error: %s", e)
        return f"Error: {str(e)}"


def _speed_badge(seconds):
    """Traffic-light emoji for a response time: green < 3s, yellow < 5s, else red."""
    if seconds < 3.0:
        return "🟢"
    if seconds < 5.0:
        return "🟡"
    return "🔴"


def process_audio(audio_path, question_text):
    """Shared UI handler: transcribe *audio_path* if given, else use *question_text*.

    Returns:
        (markdown answer with timing footer, total seconds). Error and empty
        input cases return a message and 0.0.
    """
    start_time = time.perf_counter()
    logger.info("=" * 40)

    if audio_path:
        try:
            question = _transcribe_file(audio_path)
        except Exception as e:
            logger.error("[STT] Error: %s", e)
            return f"❌ Error: {str(e)}", 0.0
    else:
        question = question_text

    if not question or not question.strip():
        return "❌ No input", 0.0

    answer = generate_answer(question)

    total_time = time.perf_counter() - start_time
    timing = f"\n\n{_speed_badge(total_time)} **Time:** {total_time:.2f}s"
    logger.info(f"[TOTAL] {total_time:.2f}s")
    logger.info("=" * 40)
    return answer + timing, total_time


def audio_handler(audio_path):
    """Gradio callback for the Audio tab."""
    return process_audio(audio_path, None)


def text_handler(text_input):
    """Gradio callback for the Text tab."""
    return process_audio(None, text_input)


# Gradio UI
with gr.Blocks(title="Fast Q&A", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ⚡ Ultra-Fast Q&A System
**Qwen 0.5B + Parallel Search** (Optimized for <3s response)
""")

    with gr.Tab("🎙️ Audio"):
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
                audio_submit = gr.Button("🚀 Submit", variant="primary", size="lg")
            with gr.Column():
                audio_output = gr.Textbox(label="Answer", lines=8, show_copy_button=True)
                audio_time = gr.Number(label="Time (s)", precision=2)
        audio_submit.click(
            fn=audio_handler,
            inputs=[audio_input],
            outputs=[audio_output, audio_time],
            api_name="audio_query",
        )

    with gr.Tab("✍️ Text"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Question", placeholder="Ask anything...", lines=3)
                text_submit = gr.Button("🚀 Submit", variant="primary", size="lg")
            with gr.Column():
                text_output = gr.Textbox(label="Answer", lines=8, show_copy_button=True)
                text_time = gr.Number(label="Time (s)", precision=2)
        text_submit.click(
            fn=text_handler,
            inputs=[text_input],
            outputs=[text_output, text_time],
            api_name="text_query",
        )
        gr.Examples(
            examples=[
                ["What are the top 3 news stories today?"],
                ["Is internet shut down in Bareilly?"],
                ["Who won 2024 US election?"],
            ],
            inputs=text_input,
        )

    with gr.Tab("🔌 API"):
        gr.Markdown("""
**Endpoints:**
- STT: `/call/transcribe_stt` → Path: `data[0].text`
- AI: `/call/answer_ai` → Path: `data[0]`
""")
        # Hidden components that exist only to expose the two API endpoints.
        with gr.Row(visible=False):
            stt_in = gr.Textbox()
            stt_out = gr.JSON()
            ai_in = gr.Textbox()
            ai_out = gr.Textbox()
        gr.Button("STT", visible=False).click(
            fn=transcribe_audio_base64, inputs=[stt_in], outputs=[stt_out], api_name="transcribe_stt"
        )
        gr.Button("AI", visible=False).click(
            fn=generate_answer, inputs=[ai_in], outputs=[ai_out], api_name="answer_ai"
        )

    gr.Markdown("""
**Speed:** Qwen 0.5B (1-2s) + Parallel search (1s) = **2-3s total**

🟢 < 3s | 🟡 3-5s | 🔴 > 5s
""")

if __name__ == "__main__":
    demo.queue(max_size=5)
    demo.launch()