# basic_app / app.py
import gradio as gr
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
import base64
import tempfile
import os
import logging
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from html.parser import HTMLParser
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Initialize models
logger.info("Loading Whisper model...")
whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8")
logger.info("Loading Qwen 2.5 0.5B-Instruct (FASTEST)...")
model_name = "Qwen/Qwen2.5-0.5B-Instruct" # SWITCHED BACK to 0.5B for speed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
device_map="cpu",
low_cpu_mem_usage=True
)
logger.info("All models loaded!")
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY', '')
BRAVE_API_KEY = os.getenv('BRAVE_API_KEY', '')
def search_tavily(query):
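    """Search via the Tavily API. Returns a numbered-snippet context string, or None on failure."""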
logger.info("[TAVILY] Starting...")
if not TAVILY_API_KEY:
return None
try:
response = requests.post(
'https://api.tavily.com/search',
json={'api_key': TAVILY_API_KEY, 'query': query, 'max_results': 3},
timeout=2 # REDUCED timeout
)
if response.status_code == 200:
data = response.json()
results = data.get('results', [])
context = ""
for i, result in enumerate(results[:3], 1):
context += f"\n[{i}] {result.get('title', '')}\n{result.get('content', '')}\n"
logger.info(f"[TAVILY] ✓")
return context
except:
pass
return None
def search_brave(query):
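    """Search via the Brave Search API. Returns a numbered-snippet context string, or None on failure."""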
logger.info("[BRAVE] Starting...")
if not BRAVE_API_KEY:
return None
try:
response = requests.get(
'https://api.search.brave.com/res/v1/web/search',
params={'q': query, 'count': 3},
headers={'X-Subscription-Token': BRAVE_API_KEY},
timeout=2
)
if response.status_code == 200:
data = response.json()
results = data.get('web', {}).get('results', [])
context = ""
for i, result in enumerate(results[:3], 1):
context += f"\n[{i}] {result.get('title', '')}\n{result.get('description', '')}\n"
logger.info(f"[BRAVE] ✓")
return context
except:
pass
return None
def search_searx(query):
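    """Try public SearXNG instances in turn. Returns a numbered-snippet context string, or None if all fail."""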
logger.info("[SEARX] Starting...")
for instance in ['https://searx.be/search', 'https://searx.work/search']:
try:
response = requests.get(
instance,
params={'q': query, 'format': 'json', 'categories': 'general'},
timeout=2
)
if response.status_code == 200:
data = response.json()
results = data.get('results', [])
context = ""
for i, result in enumerate(results[:3], 1):
context += f"\n[{i}] {result.get('title', '')}\n{result.get('content', '')}\n"
logger.info(f"[SEARX] ✓")
return context
except:
continue
return None
def search_duckduckgo_html(query):
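    """Scrape DuckDuckGo's HTML endpoint (no API key required). Returns result titles as a context string, or None."""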
logger.info("[DDG] Starting...")
try:
response = requests.get(
'https://html.duckduckgo.com/html/',
params={'q': query},
headers={'User-Agent': 'Mozilla/5.0'},
timeout=2
)
if response.status_code == 200:
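            # Minimal HTMLParser that collects the link text of result anchors (class "result__a")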
class DDGParser(HTMLParser):
def __init__(self):
super().__init__()
self.results = []
self.in_result = False
self.current_text = ""
def handle_starttag(self, tag, attrs):
                    if tag == 'a' and any(k == 'class' and v and 'result__a' in v for k, v in attrs):  # v may be None
self.in_result = True
                def handle_data(self, data):
                    if self.in_result:
                        self.current_text += data  # keep raw text; strip once at the end
                def handle_endtag(self, tag):
                    if tag == 'a' and self.in_result:
                        self.results.append(self.current_text.strip())
                        self.current_text = ""
                        self.in_result = False
parser = DDGParser()
parser.feed(response.text)
context = ""
for i, result in enumerate(parser.results[:3], 1):
context += f"\n[{i}] {result}\n"
if context:
logger.info(f"[DDG] ✓")
return context
except:
pass
return None
def search_parallel(query):
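    """Fan the query out to all engines at once; return the first hit in priority order (Tavily > Brave > Searx > DDG)."""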
logger.info("[SEARCH] Parallel start")
with ThreadPoolExecutor(max_workers=4) as executor:
futures = {
executor.submit(search_tavily, query): "Tavily",
executor.submit(search_brave, query): "Brave",
executor.submit(search_searx, query): "Searx",
executor.submit(search_duckduckgo_html, query): "DDG"
}
results = {}
for future in futures:
engine = futures[future]
try:
result = future.result(timeout=3)
if result:
results[engine] = result
            except Exception:
                logger.warning(f"[SEARCH] {engine} failed or timed out")
for engine in ["Tavily", "Brave", "Searx", "DDG"]:
if engine in results:
logger.info(f"[SEARCH] Using {engine}")
return results[engine], engine
return "No search results available.", "None"
def transcribe_audio_base64(audio_base64):
    """Decode base64-encoded WAV audio, transcribe with Whisper, return {"text": ...} or {"error": ...}."""
    logger.info("[STT] Request")
    temp_path = None
    try:
        audio_bytes = base64.b64decode(audio_base64)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio_bytes)
            temp_path = temp_audio.name
        segments, _ = whisper_model.transcribe(temp_path, language="en", beam_size=1)
        transcription = " ".join(seg.text for seg in segments)
        logger.info("[STT] ✓")
        return {"text": transcription.strip()}
    except Exception as e:
        logger.error(f"[STT] Error: {e}")
        return {"error": str(e)}
    finally:
        # Clean up the temp file even if transcription fails
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
def generate_answer(text_input):
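    """Search the web for the question, then have Qwen answer strictly from the search context."""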
logger.info(f"[AI] Q: {text_input}")
try:
if not text_input or not text_input.strip():
return "No input provided"
current_date = datetime.now().strftime("%B %d, %Y")
search_start = time.time()
search_results, search_engine = search_parallel(text_input)
search_time = time.time() - search_start
logger.info(f"[AI] Search: {search_time:.2f}s")
# IMPROVED PROMPT - Structured multi-point answers
messages = [
{
"role": "system",
"content": f"""Today is {current_date}. You are a concise assistant.
When answering:
- If question asks about multiple things, list each with a one-line description
- Use bullet points for multiple items
- Keep total answer to 80-100 words
- Answer ONLY from search results"""
},
{
"role": "user",
"content": f"""Search Results:
{search_results}
Question: {text_input}
Answer (80-100 words, use bullets if multiple topics):"""
}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
gen_start = time.time()
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1200)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=100, # REDUCED from 150
                temperature=0.7,  # moderate sampling temperature (affects output variety, not speed)
do_sample=True,
top_p=0.9,
top_k=50, # ADDED for speed
repetition_penalty=1.1,
pad_token_id=tokenizer.eos_token_id
)
gen_time = time.time() - gen_start
logger.info(f"[AI] Gen: {gen_time:.2f}s")
answer = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
answer_with_source = f"{answer}\n\n**Source:** {search_engine}"
logger.info(f"[AI] ✓")
return answer_with_source
except Exception as e:
logger.error(f"[AI] Error: {str(e)}")
return f"Error: {str(e)}"
def process_audio(audio_path, question_text):
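    """Transcribe the audio clip if given, otherwise use the typed question; return (answer, elapsed seconds)."""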
start_time = time.time()
logger.info("="*40)
if audio_path:
try:
segments, _ = whisper_model.transcribe(audio_path, language="en", beam_size=1)
question = " ".join([seg.text for seg in segments])
except Exception as e:
return f"❌ Error: {str(e)}", 0.0
else:
question = question_text
if not question or not question.strip():
return "❌ No input", 0.0
answer = generate_answer(question)
total_time = time.time() - start_time
time_emoji = "🟢" if total_time < 3.0 else "🟡" if total_time < 5.0 else "🔴"
timing = f"\n\n{time_emoji} **Time:** {total_time:.2f}s"
logger.info(f"[TOTAL] {total_time:.2f}s")
logger.info("="*40)
return answer + timing, total_time
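# Thin wrappers so each Gradio tab binds a single-input handler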
def audio_handler(audio_path):
return process_audio(audio_path, None)
def text_handler(text_input):
return process_audio(None, text_input)
# Gradio UI
with gr.Blocks(title="Fast Q&A", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# ⚡ Ultra-Fast Q&A System
**Qwen 0.5B + Parallel Search** (Optimized for <3s response)
""")
with gr.Tab("🎙️ Audio"):
with gr.Row():
with gr.Column():
audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
audio_submit = gr.Button("🚀 Submit", variant="primary", size="lg")
with gr.Column():
audio_output = gr.Textbox(label="Answer", lines=8, show_copy_button=True)
audio_time = gr.Number(label="Time (s)", precision=2)
audio_submit.click(fn=audio_handler, inputs=[audio_input], outputs=[audio_output, audio_time], api_name="audio_query")
with gr.Tab("✍️ Text"):
with gr.Row():
with gr.Column():
text_input = gr.Textbox(label="Question", placeholder="Ask anything...", lines=3)
text_submit = gr.Button("🚀 Submit", variant="primary", size="lg")
with gr.Column():
text_output = gr.Textbox(label="Answer", lines=8, show_copy_button=True)
text_time = gr.Number(label="Time (s)", precision=2)
text_submit.click(fn=text_handler, inputs=[text_input], outputs=[text_output, text_time], api_name="text_query")
gr.Examples(
examples=[
["What are the top 3 news stories today?"],
["Is internet shut down in Bareilly?"],
["Who won 2024 US election?"]
],
inputs=text_input
)
with gr.Tab("🔌 API"):
gr.Markdown("""
**Endpoints:**
- STT: `/call/transcribe_stt` → Path: `data[0].text`
- AI: `/call/answer_ai` → Path: `data[0]`
""")
with gr.Row(visible=False):
stt_in = gr.Textbox()
stt_out = gr.JSON()
ai_in = gr.Textbox()
ai_out = gr.Textbox()
gr.Button("STT", visible=False).click(fn=transcribe_audio_base64, inputs=[stt_in], outputs=[stt_out], api_name="transcribe_stt")
gr.Button("AI", visible=False).click(fn=generate_answer, inputs=[ai_in], outputs=[ai_out], api_name="answer_ai")
gr.Markdown("""
**Speed:** Qwen 0.5B (1-2s) + Parallel search (1s) = **2-3s total**
🟢 < 3s | 🟡 3-5s | 🔴 > 5s
""")
if __name__ == "__main__":
demo.queue(max_size=5)
demo.launch()