|
|
import gradio as gr |
|
|
from faster_whisper import WhisperModel |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
import requests |
|
|
import base64 |
|
|
import tempfile |
|
|
import os |
|
|
import logging |
|
|
import time |
|
|
from datetime import datetime |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from html.parser import HTMLParser |
|
|
|
|
|
|
|
|
# Root logging config: timestamped INFO-level lines used for request tracing.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

logger = logging.getLogger(__name__)

# Models are loaded once at import time so every request reuses them.
logger.info("Loading Whisper-tiny...")

# CPU-only, int8-quantized Whisper "tiny": the fastest faster-whisper setup.
whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8")

logger.info("Loading SmolLM2-360M-Instruct...")

model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# float32 on CPU; low_cpu_mem_usage keeps peak RAM down while loading.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="cpu",
    low_cpu_mem_usage=True
)

logger.info("All models loaded!")

# Optional search-engine credentials; an empty string disables that engine.
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY', '')
BRAVE_API_KEY = os.getenv('BRAVE_API_KEY', '')
|
|
|
|
|
def search_tavily(query):
    """Search the Tavily API and return up to 2 bullet-point results.

    Returns None when the API key is unset, the request fails, or the
    response is not HTTP 200, so callers can fall back to other engines.
    """
    if not TAVILY_API_KEY:
        return None
    try:
        response = requests.post(
            'https://api.tavily.com/search',
            json={'api_key': TAVILY_API_KEY, 'query': query, 'max_results': 2},
            timeout=1.5
        )
        if response.status_code == 200:
            data = response.json()
            results = data.get('results', [])
            return "\n".join(
                f"• {r.get('title', '')}: {r.get('content', '')[:120]}"
                for r in results[:2]
            )
    except (requests.RequestException, ValueError) as e:
        # Narrowed from a bare except: timeouts/connection errors and bad
        # JSON are expected best-effort failures; anything else should surface.
        logger.debug(f"[SEARCH] Tavily failed: {e}")
    return None
|
|
|
|
|
def search_brave(query):
    """Search the Brave Web Search API and return up to 2 bullet-point results.

    Returns None when the API key is unset, the request fails, or the
    response is not HTTP 200, so callers can fall back to other engines.
    """
    if not BRAVE_API_KEY:
        return None
    try:
        response = requests.get(
            'https://api.search.brave.com/res/v1/web/search',
            params={'q': query, 'count': 2},
            headers={'X-Subscription-Token': BRAVE_API_KEY},
            timeout=1.5
        )
        if response.status_code == 200:
            data = response.json()
            results = data.get('web', {}).get('results', [])
            return "\n".join(
                f"• {r.get('title', '')}: {r.get('description', '')[:120]}"
                for r in results[:2]
            )
    except (requests.RequestException, ValueError) as e:
        # Narrowed from a bare except: timeouts/connection errors and bad
        # JSON are expected best-effort failures; anything else should surface.
        logger.debug(f"[SEARCH] Brave failed: {e}")
    return None
|
|
|
|
|
def search_searx(query):
    """Query public SearX instances in order; return formatted results or None.

    Tries each instance with a short timeout and moves on to the next on any
    network/decoding error OR an empty result set — the old code returned an
    empty string on a 200 response with no results, skipping later instances.
    """
    for instance in ['https://searx.be/search', 'https://searx.work/search']:
        try:
            response = requests.get(
                instance,
                params={'q': query, 'format': 'json', 'categories': 'general', 'language': 'en'},
                timeout=1.5
            )
            if response.status_code == 200:
                data = response.json()
                results = data.get('results', [])
                if results:
                    return "\n".join(
                        f"• {r.get('title', '')}: {r.get('content', '')[:120]}"
                        for r in results[:2]
                    )
        except (requests.RequestException, ValueError) as e:
            logger.debug(f"[SEARCH] SearX instance {instance} failed: {e}")
            continue
    return None
|
|
|
|
|
def search_duckduckgo(query):
    """Scrape DuckDuckGo's HTML endpoint; return up to 2 results or None.

    Parses the HTML with a minimal HTMLParser that collects the link text of
    anchors carrying the "result__a" class.
    """

    class DDGParser(HTMLParser):
        """Accumulates text inside <a class="result__a"> result links."""

        def __init__(self):
            super().__init__()
            self.results = []
            self.in_result = False
            self.current_text = ""

        def handle_starttag(self, tag, attrs):
            # Attribute values may be None (bare attributes), so guard before
            # the substring test — the old code could raise TypeError here,
            # which the former bare except silently swallowed.
            if tag == 'a' and any(k == 'class' and v and 'result__a' in v for k, v in attrs):
                self.in_result = True

        def handle_data(self, data):
            if self.in_result and data.strip():
                self.current_text += data.strip() + " "

        def handle_endtag(self, tag):
            if tag == 'a' and self.in_result:
                if self.current_text:
                    self.results.append(self.current_text.strip()[:120])
                self.current_text = ""
                self.in_result = False

    try:
        response = requests.get(
            'https://html.duckduckgo.com/html/',
            params={'q': query},
            headers={'User-Agent': 'Mozilla/5.0'},
            timeout=1.5
        )
        if response.status_code == 200:
            parser = DDGParser()
            parser.feed(response.text)
            return "\n".join(f"• {r}" for r in parser.results[:2]) if parser.results else None
    except requests.RequestException as e:
        # Narrowed from a bare except: network failures are expected here.
        logger.debug(f"[SEARCH] DuckDuckGo failed: {e}")
    return None
|
|
|
|
|
def search_parallel(query):
    """Run all search engines concurrently; return (results_text, engine_name).

    Engines are polled in priority order (Tavily, Brave, SearX, DuckDuckGo)
    and the first non-empty result wins. A single shared 2-second deadline
    bounds the total wait — the old per-future timeout of 2 s could
    accumulate to ~8 s when every engine hung.
    """
    logger.info("[SEARCH] Starting parallel search...")

    deadline = time.time() + 2.0
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = {
            executor.submit(search_tavily, query): "Tavily",
            executor.submit(search_brave, query): "Brave",
            executor.submit(search_searx, query): "Searx",
            executor.submit(search_duckduckgo, query): "DuckDuckGo"
        }

        for future, engine in futures.items():
            remaining = max(0.0, deadline - time.time())
            try:
                result = future.result(timeout=remaining)
            except Exception as e:
                # TimeoutError or an engine-level failure: log and try the
                # next engine instead of silently swallowing.
                logger.debug(f"[SEARCH] {engine} failed: {e}")
                continue
            if result:
                logger.info(f"[SEARCH] ✓ {engine}")
                return result, engine

    logger.warning("[SEARCH] All engines failed")
    return "No search results available.", "None"
|
|
|
|
|
def transcribe_audio_base64(audio_base64):
    """Decode base64-encoded WAV audio and transcribe it with Whisper.

    Returns {"text": transcription} on success or {"error": message} on
    failure — API clients read `data[0].text` from the JSON response.
    """
    logger.info("[STT] Processing audio...")
    try:
        audio_bytes = base64.b64decode(audio_base64)

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio_bytes)
            temp_path = temp_audio.name

        try:
            segments, _ = whisper_model.transcribe(temp_path, language="en", beam_size=1)
            transcription = " ".join(seg.text for seg in segments)
        finally:
            # Always remove the temp file — the old code leaked it whenever
            # transcription raised.
            os.unlink(temp_path)

        logger.info("[STT] ✓ Transcribed")
        return {"text": transcription.strip()}

    except Exception as e:
        # Boundary handler: never raise to the API layer, return an error dict.
        logger.error(f"[STT] Error: {str(e)}")
        return {"error": str(e)}
|
|
|
|
|
def generate_answer(text_input):
    """Main answer generation - with debug logging.

    Validates the incoming question, runs a parallel web search, builds a
    ChatML prompt for SmolLM2, and generates a short search-grounded answer.
    Returns the answer with a "**Source:**" footer, or a readable error
    string — this function never raises to its caller (API boundary).
    """
    logger.info("=" * 60)
    logger.info(f"[AI] Raw input: '{text_input}'")
    logger.info(f"[AI] Input type: {type(text_input)}, Length: {len(text_input) if text_input else 0}")

    try:
        # Reject empty input and un-substituted template variables that a
        # misconfigured Pluely client sends through verbatim.
        if not text_input or text_input.strip() in ["", "{{TEXT}}", "{{text}}", "$TEXT"]:
            error_msg = "❌ ERROR: No question received. Pluely sent empty/template variable.\n\nPluely Config Issue:\n- Check your curl command uses correct format\n- Make sure variable substitution is enabled"
            logger.error(f"[AI] {error_msg}")
            return error_msg

        # Clamp pathological question lengths up front. Previously a very
        # long question could make the tokenizer's right-side truncation
        # (max_length=800 below) cut off the trailing assistant tag of the
        # prompt and derail generation.
        text_input = text_input[:1000]

        current_date = datetime.now().strftime("%B %d, %Y")

        search_start = time.time()
        search_results, search_engine = search_parallel(text_input)
        search_time = time.time() - search_start
        logger.info(f"[AI] Search completed in {search_time:.2f}s")

        messages = [
            {
                "role": "system",
                "content": f"You are a helpful assistant. Today is {current_date}. Answer questions using the provided search results. Be concise (60-80 words). Use bullet points for multiple items."
            },
            {
                "role": "user",
                "content": f"Search Results:\n{search_results}\n\nQuestion: {text_input}\n\nAnswer based strictly on search results (60-80 words):"
            }
        ]

        # SmolLM2 uses the ChatML format; the prompt is assembled by hand
        # and ends with the assistant tag so generation continues from there.
        prompt = f"<|im_start|>system\n{messages[0]['content']}<|im_end|>\n<|im_start|>user\n{messages[1]['content']}<|im_end|>\n<|im_start|>assistant\n"

        gen_start = time.time()
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=800)

        logger.info("[AI] Generating answer...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        gen_time = time.time() - gen_start
        logger.info(f"[AI] Generation completed in {gen_time:.2f}s")

        # Decode only the newly generated tokens (slice off the prompt echo).
        answer = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
        full_answer = f"{answer}\n\n**Source:** {search_engine}"

        logger.info("[AI] ✓ Complete")
        logger.info("=" * 60)
        return full_answer

    except Exception as e:
        # Top-level boundary: log the full traceback, return a readable error.
        logger.exception(f"[AI] Error: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
|
|
def process_audio(audio_path, question_text):
    """Answer a question supplied as an audio file or as raw text.

    Transcribes the audio when a path is given (otherwise uses
    question_text), generates an answer, and returns a tuple of
    (answer with a performance footer, total elapsed seconds).
    """
    t0 = time.time()
    logger.info("=" * 50)
    logger.info("[MAIN] New request received")

    if audio_path:
        logger.info(f"[MAIN] Processing audio: {audio_path}")
        try:
            # Segments are consumed lazily, so the join stays inside the try.
            segments, _ = whisper_model.transcribe(audio_path, language="en", beam_size=1)
            question = " ".join(seg.text for seg in segments)
            logger.info(f"[MAIN] Transcribed: {question}")
        except Exception as e:
            logger.error(f"[MAIN] Transcription failed: {str(e)}")
            return f"❌ Transcription error: {str(e)}", 0.0
    else:
        question = question_text
        logger.info(f"[MAIN] Text input: {question}")

    # Guard clause: nothing usable arrived from either input path.
    if not (question and question.strip()):
        logger.warning("[MAIN] No input provided")
        return "❌ No input provided", 0.0

    transcription_time = time.time() - t0

    t_gen = time.time()
    answer = generate_answer(question)
    gen_time = time.time() - t_gen

    total_time = time.time() - t0
    # Traffic-light indicator for overall latency.
    if total_time < 2.0:
        time_emoji = "🟢"
    elif total_time < 3.0:
        time_emoji = "🟡"
    else:
        time_emoji = "🔴"

    timing = (
        f"\n\n{time_emoji} **Performance:** "
        f"Trans={transcription_time:.2f}s | Search+Gen={gen_time:.2f}s | "
        f"**Total={total_time:.2f}s**"
    )

    logger.info(f"[MAIN] Total time: {total_time:.2f}s")
    logger.info("=" * 50)

    return answer + timing, total_time
|
|
|
|
|
def audio_handler(audio_path):
    """Gradio callback for the audio tab: delegate with no text question."""
    return process_audio(audio_path=audio_path, question_text=None)
|
|
|
|
|
def text_handler(text_input):
    """Gradio callback for the text tab: delegate with no audio path."""
    return process_audio(audio_path=None, question_text=text_input)
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Three visible tabs (audio input, text input, Pluely API docs) plus hidden
# components that expose named API endpoints for external clients.
with gr.Blocks(title="Ultra-Fast Q&A - SmolLM2-360M", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ⚡ Ultra-Fast Political Q&A System
    **SmolLM2-360M** (250-400 tok/s) + **Parallel Search**
    """)

    with gr.Tab("🎙️ Audio Input"):
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
                audio_submit = gr.Button("🚀 Submit", variant="primary")
            with gr.Column():
                audio_output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)
                audio_time = gr.Number(label="Time (s)", precision=2)

        # api_name exposes this handler at /call/audio_query for API clients.
        audio_submit.click(fn=audio_handler, inputs=[audio_input], outputs=[audio_output, audio_time], api_name="audio_query")

    with gr.Tab("✍️ Text Input"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Question", placeholder="Ask anything...", lines=3)
                text_submit = gr.Button("🚀 Submit", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)
                text_time = gr.Number(label="Time (s)", precision=2)

        # api_name exposes this handler at /call/text_query for API clients.
        text_submit.click(fn=text_handler, inputs=[text_input], outputs=[text_output, text_time], api_name="text_query")

        gr.Examples(examples=[["Who is the US president?"]], inputs=text_input)

    # Documentation-only tab: curl recipes for wiring the Pluely client to
    # the hidden API endpoints declared below.
    with gr.Tab("🔌 Pluely API"):
        gr.Markdown("""
    ## ⚠️ IMPORTANT: Pluely Configuration

    ### If you see "{{TEXT}}" in logs, try these formats:

    **Format 1 (Windows CMD - Use This First):**
    ```
    curl -X POST https://archcoder-basic-app.hf.space/call/answer_ai -H "Content-Type: application/json" -d "{\\"data\\": [\\"TEXT_PLACEHOLDER\\"]}"
    ```
    Then in Pluely, replace `TEXT_PLACEHOLDER` with `{{TEXT}}`

    **Format 2 (Alternative):**
    ```
    curl -X POST https://archcoder-basic-app.hf.space/call/answer_ai -H "Content-Type: application/json" --data-binary "{\\"data\\": [\\"{{TEXT}}\\"]}"
    ```

    **Response Path:** `data[0]`

    ---

    ### STT Endpoint:
    ```
    curl -X POST https://archcoder-basic-app.hf.space/call/transcribe_stt -H "Content-Type: application/json" -d "{\\"data\\": [\\"{{AUDIO_BASE64}}\\"]}"
    ```
    **Response Path:** `data[0].text`
    """)

    # Hidden components exist only to register the named API endpoints
    # (/call/transcribe_stt and /call/answer_ai) used by Pluely.
    with gr.Row(visible=False):
        stt_in = gr.Textbox()
        stt_out = gr.JSON()
        ai_in = gr.Textbox()
        ai_out = gr.Textbox()

    gr.Button("STT", visible=False).click(fn=transcribe_audio_base64, inputs=[stt_in], outputs=[stt_out], api_name="transcribe_stt")
    gr.Button("AI", visible=False).click(fn=generate_answer, inputs=[ai_in], outputs=[ai_out], api_name="answer_ai")

    # Legend for the latency traffic-light emitted by process_audio.
    gr.Markdown("🟢 < 2s | 🟡 2-3s | 🔴 > 3s")
|
|
|
|
|
if __name__ == "__main__":
    # Bound the request queue so slow generations don't pile up unbounded.
    demo.queue(max_size=5)
    demo.launch()
|
|
|