# basic_app/app.py
import gradio as gr
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
import base64
import tempfile
import os
import logging
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from html.parser import HTMLParser
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Initialize models
logger.info("Loading Whisper model...")
whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8")
logger.info("Loading Qwen 2.5 1.5B-Instruct...")
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
device_map="cpu",
low_cpu_mem_usage=True
)
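# Both models are pinned to CPU (Whisper tiny in int8, Qwen in float32).
# These settings target a CPU-only host such as a free Hugging Face Space;
# drop the device/compute_type overrides if a GPU is available.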
logger.info("All models loaded successfully!")
# Search APIs configuration
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY', '')
BRAVE_API_KEY = os.getenv('BRAVE_API_KEY', '')
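# Each search backend below returns a formatted context string on success and
# None on failure, so search_parallel() can fall through them in priority
# order without special-casing missing API keys.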
def search_tavily(query):
"""Priority 1: Tavily AI search"""
logger.info("[TAVILY] Starting...")
if not TAVILY_API_KEY:
logger.warning("[TAVILY] No API key")
return None
try:
response = requests.post(
'https://api.tavily.com/search',
json={'api_key': TAVILY_API_KEY, 'query': query, 'max_results': 3},
timeout=3
)
if response.status_code == 200:
data = response.json()
results = data.get('results', [])
context = ""
for i, result in enumerate(results[:3], 1):
context += f"\n[Tavily {i}] {result.get('title', '')}\n{result.get('content', '')}\n"
logger.info(f"[TAVILY] Success - {len(results)} results")
return context
except Exception as e:
logger.error(f"[TAVILY] Error: {str(e)}")
return None
def search_brave(query):
"""Priority 2: Brave Search"""
logger.info("[BRAVE] Starting...")
if not BRAVE_API_KEY:
logger.warning("[BRAVE] No API key")
return None
try:
response = requests.get(
'https://api.search.brave.com/res/v1/web/search',
params={'q': query, 'count': 3},
headers={'X-Subscription-Token': BRAVE_API_KEY},
timeout=3
)
if response.status_code == 200:
data = response.json()
results = data.get('web', {}).get('results', [])
context = ""
for i, result in enumerate(results[:3], 1):
context += f"\n[Brave {i}] {result.get('title', '')}\n{result.get('description', '')}\n"
logger.info(f"[BRAVE] Success - {len(results)} results")
return context
except Exception as e:
logger.error(f"[BRAVE] Error: {str(e)}")
return None
def search_searx(query):
"""Priority 3: Searx"""
logger.info("[SEARX] Starting...")
searx_instances = [
'https://searx.be/search',
'https://searx.work/search',
'https://search.sapti.me/search'
]
for instance in searx_instances:
try:
response = requests.get(
instance,
params={'q': query, 'format': 'json', 'categories': 'general'},
timeout=3
)
if response.status_code == 200:
data = response.json()
results = data.get('results', [])
context = ""
for i, result in enumerate(results[:3], 1):
context += f"\n[Searx {i}] {result.get('title', '')}\n{result.get('content', '')}\n"
logger.info(f"[SEARX] Success from {instance}")
return context
except Exception as e:
logger.warning(f"[SEARX] Failed {instance}: {str(e)}")
return None
def search_duckduckgo_html(query):
"""Priority 4: DuckDuckGo HTML"""
logger.info("[DDG] Starting...")
try:
response = requests.get(
'https://html.duckduckgo.com/html/',
params={'q': query},
headers={'User-Agent': 'Mozilla/5.0'},
timeout=3
)
if response.status_code == 200:
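            # Minimal stdlib parser: DuckDuckGo's HTML endpoint marks result
            # titles with <a class="result__a">, so we just collect that text.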
class DDGParser(HTMLParser):
def __init__(self):
super().__init__()
self.results = []
self.in_result = False
self.current_text = ""
def handle_starttag(self, tag, attrs):
                    # attribute values can be None (bare attributes), so guard
                    # v before the substring check
                    if tag == 'a' and any(k == 'class' and v and 'result__a' in v for k, v in attrs):
self.in_result = True
def handle_data(self, data):
if self.in_result:
self.current_text += data.strip()
def handle_endtag(self, tag):
if tag == 'a' and self.in_result:
self.results.append(self.current_text)
self.current_text = ""
self.in_result = False
parser = DDGParser()
parser.feed(response.text)
context = ""
for i, result in enumerate(parser.results[:3], 1):
context += f"\n[DDG {i}] {result}\n"
if context:
logger.info(f"[DDG] Success")
return context
except Exception as e:
logger.error(f"[DDG] Error: {str(e)}")
return None
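# Fan out to every engine at once and keep the first hit in priority order.
# The sequential future.result(timeout=4) calls below look serial, but the
# searches run concurrently, so the total wait is roughly the slowest engine.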
def search_parallel(query):
"""Execute all searches in parallel"""
logger.info("[PARALLEL] Starting all engines...")
with ThreadPoolExecutor(max_workers=4) as executor:
futures = {
executor.submit(search_tavily, query): "Tavily",
executor.submit(search_brave, query): "Brave",
executor.submit(search_searx, query): "Searx",
executor.submit(search_duckduckgo_html, query): "DuckDuckGo"
}
priority_order = ["Tavily", "Brave", "Searx", "DuckDuckGo"]
results = {}
for future in futures:
engine = futures[future]
try:
result = future.result(timeout=4)
if result:
results[engine] = result
logger.info(f"[PARALLEL] {engine} completed")
except Exception as e:
logger.error(f"[PARALLEL] {engine} failed: {str(e)}")
for engine in priority_order:
if engine in results and results[engine]:
logger.info(f"[PARALLEL] Using {engine}")
return results[engine], engine
logger.error("[PARALLEL] All failed")
return "Unable to fetch search results.", "None"
def transcribe_audio_base64(audio_base64):
"""Transcribe audio"""
logger.info("[PLUELY STT] Request")
try:
audio_bytes = base64.b64decode(audio_base64)
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
temp_audio.write(audio_bytes)
temp_path = temp_audio.name
        try:
            segments, _ = whisper_model.transcribe(temp_path, language="en", beam_size=1)
            transcription = " ".join(seg.text for seg in segments)
        finally:
            os.unlink(temp_path)  # remove the temp file even if transcription fails
        logger.info("[PLUELY STT] Success")
        return {"text": transcription.strip()}
except Exception as e:
logger.error(f"[PLUELY STT] Error: {str(e)}")
return {"error": str(e)}
def generate_answer(text_input):
"""Generate answer"""
logger.info(f"[PLUELY AI] Question: {text_input}")
try:
if not text_input or not text_input.strip():
return "No input provided"
current_date = datetime.now().strftime("%B %d, %Y")
logger.info("[PLUELY AI] Searching...")
search_results, search_engine = search_parallel(text_input)
logger.info(f"[PLUELY AI] Using {search_engine}")
messages = [
{"role": "system", "content": f"Today is {current_date}. Answer using ONLY the search results. Be concise (100-120 words)."},
{"role": "user", "content": f"Search Results:\n{search_results}\n\nQuestion: {text_input}\n\nAnswer based strictly on search results:"}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
logger.info("[PLUELY AI] Generating...")
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1500)
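        # Low temperature plus nucleus sampling keeps answers grounded in the
        # retrieved context, while repetition_penalty avoids verbatim echoing.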
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
temperature=0.4,
do_sample=True,
top_p=0.9,
repetition_penalty=1.1,
pad_token_id=tokenizer.eos_token_id
)
answer = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
answer_with_source = f"{answer}\n\n**Source:** {search_engine}"
logger.info(f"[PLUELY AI] Done")
return answer_with_source
except Exception as e:
logger.error(f"[PLUELY AI] Error: {str(e)}")
return f"Error: {str(e)}"
def process_audio(audio_path, question_text):
"""Main pipeline"""
start_time = time.time()
logger.info("="*50)
if audio_path:
try:
segments, _ = whisper_model.transcribe(audio_path, language="en", beam_size=1)
question = " ".join([seg.text for seg in segments])
except Exception as e:
return f"❌ Error: {str(e)}", 0.0
else:
question = question_text
if not question or not question.strip():
return "❌ No input", 0.0
answer = generate_answer(question)
total_time = time.time() - start_time
time_emoji = "🟢" if total_time < 4.0 else "🟡" if total_time < 6.0 else "🔴"
timing = f"\n\n{time_emoji} **Time:** {total_time:.2f}s"
logger.info(f"[MAIN] Total: {total_time:.2f}s")
logger.info("="*50)
return answer + timing, total_time
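# Thin wrappers so each Gradio tab binds a single-input handler to the shared
# pipeline.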
def audio_handler(audio_path):
return process_audio(audio_path, None)
def text_handler(text_input):
return process_audio(None, text_input)
# Gradio UI
with gr.Blocks(title="Fast Q&A", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# ⚡ Fast Political Q&A
**Parallel multi-search + Qwen 2.5 1.5B**
""")
with gr.Tab("🎙️ Audio"):
with gr.Row():
with gr.Column():
audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
audio_submit = gr.Button("🚀 Submit", variant="primary", size="lg")
with gr.Column():
audio_output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)
audio_time = gr.Number(label="Time (s)", precision=2)
audio_submit.click(fn=audio_handler, inputs=[audio_input], outputs=[audio_output, audio_time], api_name="audio_query")
with gr.Tab("✍️ Text"):
with gr.Row():
with gr.Column():
text_input = gr.Textbox(label="Question", placeholder="Ask anything...", lines=3)
text_submit = gr.Button("🚀 Submit", variant="primary", size="lg")
with gr.Column():
text_output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)
text_time = gr.Number(label="Time (s)", precision=2)
text_submit.click(fn=text_handler, inputs=[text_input], outputs=[text_output, text_time], api_name="text_query")
gr.Examples(
examples=[
["Is internet shut down in Bareilly today?"],
["Who won 2024 US election?"]
],
inputs=text_input
)
with gr.Tab("🔌 API"):
gr.Markdown("""
        **Pluely Endpoints:**
        - STT: `https://archcoder-basic-app.hf.space/call/transcribe_stt`
        - AI: `https://archcoder-basic-app.hf.space/call/answer_ai`

        **Response Paths:**
        - STT: `data[0].text`
        - AI: `data[0]`
""")
with gr.Row(visible=False):
stt_in = gr.Textbox()
stt_out = gr.JSON()
ai_in = gr.Textbox()
ai_out = gr.Textbox()
gr.Button("STT", visible=False).click(fn=transcribe_audio_base64, inputs=[stt_in], outputs=[stt_out], api_name="transcribe_stt")
gr.Button("AI", visible=False).click(fn=generate_answer, inputs=[ai_in], outputs=[ai_out], api_name="answer_ai")
gr.Markdown("🟢 < 4s | 🟡 4-6s | 🔴 > 6s")
if __name__ == "__main__":
demo.queue(max_size=5)
demo.launch()