|
|
import os
import threading
import time

import gradio as gr
from brave import Brave
from faster_whisper import WhisperModel
from llama_cpp import Llama
|
|
|
|
|
|
|
|
print("Loading models...")

# Speech-to-text: the "tiny" Whisper checkpoint, run on CPU with int8 compute.
whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8")

# Answer generator: Q4_K_M-quantized Qwen2.5-0.5B-Instruct GGUF, obtained via
# Llama.from_pretrained from the given repo/filename. The 2048-token context
# has to hold the prompt (instructions + search context + question) plus the
# generated answer.
llm = Llama.from_pretrained(
    repo_id="Qwen/Qwen2.5-0.5B-Instruct-GGUF",
    filename="qwen2.5-0.5b-instruct-q4_k_m.gguf",
    n_ctx=2048,
    n_threads=4,
    verbose=False,
)

# Web search client. When BRAVE_API_KEY is unset an empty key is passed
# through; warn at startup so the misconfiguration is visible immediately
# rather than only as per-request "Search failed" text inside answers.
_brave_api_key = os.getenv("BRAVE_API_KEY", "")
if not _brave_api_key:
    print("Warning: BRAVE_API_KEY is not set; web search requests will likely fail.")
brave_client = Brave(api_key=_brave_api_key)
|
|
|
|
|
def search_web(query, max_results=3):
    """Search the web with Brave and format the top hits as prompt context.

    Args:
        query: Free-text search query.
        max_results: Number of hits to request and to include in the output.

    Returns:
        Numbered entries of the form ``"[i] title\\ndescription"`` separated
        by blank lines (an empty string when there are no hits), or
        ``"Search failed: <error>"`` if anything goes wrong. This function
        never raises, so the caller can always embed the result in a prompt.
    """
    try:
        results = brave_client.search(q=query, count=max_results)
        # The response may lack the attribute or hold None when there are no
        # web hits; treat both as "no results".
        web_results = getattr(results, "web_results", None) or []

        entries = [
            # Guard against missing fields, which would otherwise be rendered
            # into the LLM context as the literal text "None".
            f"[{rank}] {hit.title or ''}\n{hit.description or ''}"
            for rank, hit in enumerate(web_results[:max_results], 1)
        ]
        return "\n\n".join(entries).strip()
    except Exception as e:
        # Deliberate best-effort boundary: a search outage or unexpected
        # response shape degrades the answer instead of failing the request.
        return f"Search failed: {e}"
|
|
|
|
|
def process_audio(audio_path, question_text=None): |
|
|
"""Main pipeline: audio -> text -> search -> answer""" |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
if audio_path: |
|
|
segments, _ = whisper_model.transcribe(audio_path, language="en") |
|
|
question = " ".join([seg.text for seg in segments]) |
|
|
else: |
|
|
question = question_text |
|
|
|
|
|
if not question: |
|
|
return "No input provided", 0.0 |
|
|
|
|
|
transcription_time = time.time() - start_time |
|
|
|
|
|
|
|
|
search_start = time.time() |
|
|
search_results = search_web(question) |
|
|
search_time = time.time() - search_start |
|
|
|
|
|
|
|
|
llm_start = time.time() |
|
|
prompt = f"""You are a helpful assistant. Answer the question based on the context below. |
|
|
|
|
|
Context from web search: |
|
|
{search_results} |
|
|
|
|
|
Question: {question} |
|
|
|
|
|
Answer briefly and accurately:""" |
|
|
|
|
|
response = llm( |
|
|
prompt, |
|
|
max_tokens=150, |
|
|
temperature=0.3, |
|
|
top_p=0.9, |
|
|
stop=["Question:", "\n\n"], |
|
|
echo=False |
|
|
) |
|
|
|
|
|
answer = response['choices'][0]['text'].strip() |
|
|
llm_time = time.time() - llm_start |
|
|
|
|
|
total_time = time.time() - start_time |
|
|
|
|
|
timing_info = f"\n\n⏱️ Timing: Transcription={transcription_time:.2f}s | Search={search_time:.2f}s | LLM={llm_time:.2f}s | Total={total_time:.2f}s" |
|
|
|
|
|
return answer + timing_info, total_time |
|
|
|
|
|
|
|
|
with gr.Blocks(title="Fast Q&A with Web Search") as demo:
    gr.Markdown("# 🎤 Fast Political Q&A System\nAsk questions via audio or text. Answers in ~3 seconds!")

    # Audio tab: record/upload a clip; the pipeline transcribes it first.
    with gr.Tab("Audio Input"):
        audio_input = gr.Audio(type="filepath", label="Record or upload audio question")
        audio_submit = gr.Button("Submit Audio", variant="primary")
        audio_output = gr.Textbox(label="Answer", lines=6)
        audio_time = gr.Number(label="Response Time (seconds)")

        def _answer_audio(recording_path):
            """Send an audio submission through the pipeline (no typed text)."""
            return process_audio(recording_path, None)

        audio_submit.click(
            api_name="audio_query",
            fn=_answer_audio,
            inputs=[audio_input],
            outputs=[audio_output, audio_time],
        )

    # Text tab: skip transcription and answer a typed question directly.
    with gr.Tab("Text Input"):
        text_input = gr.Textbox(label="Type your question", placeholder="Who won the 2024 elections?")
        text_submit = gr.Button("Submit Text", variant="primary")
        text_output = gr.Textbox(label="Answer", lines=6)
        text_time = gr.Number(label="Response Time (seconds)")

        def _answer_text(typed_question):
            """Send a typed question through the pipeline (no audio)."""
            return process_audio(None, typed_question)

        text_submit.click(
            api_name="text_query",
            fn=_answer_text,
            inputs=[text_input],
            outputs=[text_output, text_time],
        )

    # Usage notes for calling the named endpoints over HTTP.
    gr.Markdown("""
    ### 📡 API Usage
    ```
    # Upload audio file
    curl -F "files=@audio.mp3" https://YOUR-SPACE-URL/upload

    # Make query
    curl -X POST https://YOUR-SPACE-URL/call/audio_query \\
      -H "Content-Type: application/json" \\
      -d '{"data": [{"path": "/tmp/uploaded_audio.mp3"}]}'
    ```
    """)
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces (not just localhost) on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)
|
|
|