|
|
import base64
import logging
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from html.parser import HTMLParser

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
|
|
# Root logging: INFO and above, each record prefixed with its timestamp.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# --- Speech-to-text model: Whisper "tiny", int8 compute on CPU. ---
logger.info("Loading Whisper-tiny...")
whisper_model = WhisperModel("tiny", compute_type="int8", device="cpu")

# --- Text-generation model: SmolLM2-360M-Instruct, float32 on CPU. ---
logger.info("Loading SmolLM2-360M-Instruct...")
model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

logger.info("All models loaded!")
|
|
|
|
|
# Optional search-provider credentials, read from the environment.
# Each is an empty string when the variable is unset. search_parallel below
# only queries DuckDuckGo and does not read either key.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY", "")
BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "")
|
|
|
|
|
class _DDGResultParser(HTMLParser):
    """Collect the visible text of DuckDuckGo result-title links.

    A result title is recognised as an ``<a>`` tag whose ``class`` list
    contains the token ``result__a``. That is the marker this scraper keys on;
    revisit it if DuckDuckGo changes its HTML.
    """

    def __init__(self, max_chars=120):
        super().__init__()
        self.max_chars = max_chars  # each stored title is cut to this length
        self.results = []           # collected result titles, in page order
        self._in_result = False     # True while inside a result-title <a>
        self._parts = []            # stripped text chunks of the current title

    def handle_starttag(self, tag, attrs):
        if tag != "a":
            return
        for name, value in attrs:
            # ``value`` is None for valueless attributes (e.g. ``<a class>``).
            # Guard it so one odd tag cannot abort the whole parse.
            if name == "class" and value and "result__a" in value.split():
                self._in_result = True
                return

    def handle_data(self, data):
        if self._in_result:
            text = data.strip()
            if text:
                self._parts.append(text)

    def handle_endtag(self, tag):
        if tag == "a" and self._in_result:
            if self._parts:
                self.results.append(" ".join(self._parts)[: self.max_chars])
            self._parts = []
            self._in_result = False


def search_parallel(query, timeout=1.5, max_results=2):
    """Fetch a few result titles for *query* from DuckDuckGo's HTML endpoint.

    Despite the name, only one engine (DuckDuckGo) is queried. Search is
    best-effort: network failures, non-200 responses and parse errors are
    logged and turned into a "no results" answer instead of propagating.

    Args:
        query: Search text, sent as the ``q`` query parameter.
        timeout: Seconds to wait for the HTTP response (default 1.5).
        max_results: Maximum number of result titles to include (default 2).

    Returns:
        A ``(text, engine)`` tuple:
        - On success, ``text`` holds one ``"• title"`` line per result (or
          ``"No results"`` if the page had none) and ``engine`` is
          ``"DuckDuckGo"``.
        - On any failure the tuple is ``("No search results", "None")``.
    """
    logger.info("[SEARCH] Starting...")
    try:
        response = requests.get(
            "https://html.duckduckgo.com/html/",
            params={"q": query},
            headers={"User-Agent": "Mozilla/5.0"},
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # A slow or unreachable engine must not block answering.
        logger.warning("[SEARCH] Request failed: %s", exc)
        return "No search results", "None"

    if response.status_code != 200:
        logger.warning("[SEARCH] Unexpected HTTP status %s", response.status_code)
        return "No search results", "None"

    parser = _DDGResultParser()
    try:
        parser.feed(response.text)
    except Exception:  # deliberate best-effort boundary: log, don't crash
        logger.exception("[SEARCH] Failed to parse results")
        return "No search results", "None"

    top = parser.results[:max_results]
    result = "\n".join(f"• {r}" for r in top) if top else "No results"
    logger.info("[SEARCH] ✓")
    return result, "DuckDuckGo"
|
|
|
|
|
# Token budget for the user's question inside the prompt. The question is
# capped before the prompt is built, rather than right-truncating the finished
# prompt. That way the assistant header that ends the prompt is never cut
# off, and the model always starts a fresh assistant turn. System text, search
# snippets and template tokens take roughly 150 more tokens on top of this.
_MAX_QUESTION_TOKENS = 600


def _truncate_to_tokens(text, max_tokens):
    """Return *text* shortened to at most *max_tokens* tokenizer tokens."""
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(token_ids) <= max_tokens:
        return text
    logger.info("[AI] Question truncated from %d to %d tokens", len(token_ids), max_tokens)
    return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)


def generate_answer(text_input):
    """Answer a question with SmolLM2, grounded on a quick web search.

    Args:
        text_input: The user's question. Blank or non-string input is rejected.

    Returns:
        One of:
        - the sampled answer followed by a ``**Source:**`` line naming the
          search engine;
        - ``"No input provided"`` for empty/blank/non-string input;
        - ``"Error: <message>"`` if search or generation raised.
    """
    # Validate before touching the value: slicing or stripping a non-string
    # (e.g. None from a JSON null) would otherwise raise outside the try.
    if not isinstance(text_input, str) or not text_input.strip():
        return "No input provided"

    question = text_input.strip()
    logger.info("[AI] Question: %s...", question[:60])

    try:
        question = _truncate_to_tokens(question, _MAX_QUESTION_TOKENS)
        current_date = datetime.now().strftime("%B %d, %Y")

        search_start = time.perf_counter()
        search_results, search_engine = search_parallel(question)
        logger.info("[AI] Search: %.2fs", time.perf_counter() - search_start)

        messages = [
            {"role": "system", "content": f"Today is {current_date}. Answer briefly using search results (60-80 words)."},
            {"role": "user", "content": f"Search:\n{search_results}\n\nQ: {question}\nA:"},
        ]
        # Render with the model's own chat template. add_generation_prompt
        # appends the assistant header so generation starts a new turn.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        gen_start = time.perf_counter()
        inputs = tokenizer(prompt, return_tensors="pt")
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        logger.info("[AI] Gen: %.2fs | ✓", time.perf_counter() - gen_start)

        return f"{answer}\n\n**Source:** {search_engine}"

    except Exception as e:
        logger.exception("[AI] Error")
        return f"Error: {e}"
|
|
|
|
|
def transcribe_audio_base64(audio_base64):
    """Transcribe base64-encoded audio to English text with the Whisper model.

    The decoded bytes are written to a temporary ``.wav`` file whose path is
    passed to ``whisper_model.transcribe``. The file is always removed
    afterwards, even when decoding or transcription fails.

    Args:
        audio_base64: Base64 text (``str`` or ``bytes``) of an audio file.

    Returns:
        The transcription, with segments joined by single spaces. Returns
        ``""`` if anything failed; the error is logged.
    """
    logger.info("[STT] Start")
    temp_path = None
    try:
        audio_bytes = base64.b64decode(audio_base64)

        # delete=False keeps the file on disk after this ``with`` closes it,
        # so Whisper can open it by path. Cleanup happens in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_path = temp_audio.name  # record first so a failed write is still cleaned up
            temp_audio.write(audio_bytes)

        segments, _ = whisper_model.transcribe(temp_path, language="en", beam_size=1)
        # faster-whisper returns segments as a generator, so decoding errors
        # may only surface while iterating. Keep the join inside the try.
        # Stripping each segment avoids doubled spaces between segments.
        transcription = " ".join(
            text for text in (seg.text.strip() for seg in segments) if text
        )

        logger.info("[STT] ✓")
        return transcription

    except Exception as e:
        logger.error("[STT] Error: %s", e)
        return ""

    finally:
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError as e:
                logger.warning("[STT] Could not delete temp file %s: %s", temp_path, e)
|
|
|
|
|
|
|
|
# FastAPI application hosting the JSON endpoints. The Gradio UI is mounted
# onto this same app at the end of the file.
app = FastAPI()
|
|
|
|
|
@app.post("/api/stt")
async def api_stt(request: Request):
    """Direct STT endpoint for Pluely.

    Expects a JSON object ``{"audio": "<base64 audio>"}`` and responds with
    ``{"text": "<transcription>"}``.

    Returns:
        A ``JSONResponse`` with:
        - status 400 for malformed JSON, a non-object body, or a missing,
          empty or non-string ``audio`` field;
        - status 500 for unexpected failures.
    """
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError are ValueErrors
        logger.warning("[API STT] Rejected malformed JSON body")
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    except Exception as e:
        logger.exception("[API STT] Error reading request body")
        return JSONResponse({"error": str(e)}, status_code=500)

    if not isinstance(body, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    audio_base64 = body.get("audio", "")
    if not audio_base64:
        return JSONResponse({"error": "No audio data"}, status_code=400)
    if not isinstance(audio_base64, str):
        return JSONResponse({"error": "\"audio\" must be a base64 string"}, status_code=400)

    # Log only the size: the body is an entire base64-encoded recording.
    logger.info("[API STT] Received %d base64 chars (keys: %s)", len(audio_base64), sorted(body))

    try:
        # Whisper inference is blocking CPU work. Run it in a worker thread so
        # the event loop (which also serves the Gradio UI and /health) stays
        # responsive.
        text = await run_in_threadpool(transcribe_audio_base64, audio_base64)
    except Exception as e:
        logger.exception("[API STT] Error")
        return JSONResponse({"error": str(e)}, status_code=500)

    return JSONResponse({"text": text})
|
|
|
|
|
@app.post("/api/ai")
async def api_ai(request: Request):
    """Direct AI endpoint for Pluely.

    Expects a JSON object ``{"text": "<question>"}`` and responds with
    ``{"answer": "<answer>"}`` (the string produced by ``generate_answer``).

    Returns:
        A ``JSONResponse`` with:
        - status 400 for malformed JSON, a non-object body, or a missing,
          empty or non-string ``text`` field;
        - status 500 for unexpected failures.
    """
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError are ValueErrors
        logger.warning("[API AI] Rejected malformed JSON body")
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    except Exception as e:
        logger.exception("[API AI] Error reading request body")
        return JSONResponse({"error": str(e)}, status_code=500)

    if not isinstance(body, dict):
        return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

    question = body.get("text", "")
    if not question:
        return JSONResponse({"error": "No text provided"}, status_code=400)
    if not isinstance(question, str):
        return JSONResponse({"error": "\"text\" must be a string"}, status_code=400)

    logger.info("[API AI] Received question: %s", question[:100])

    try:
        # Search + LLM generation is blocking work that takes seconds. Run it
        # in a worker thread so the event loop (which also serves the Gradio
        # UI and /health) is not stalled.
        answer = await run_in_threadpool(generate_answer, question)
    except Exception as e:
        logger.exception("[API AI] Error")
        return JSONResponse({"error": str(e)}, status_code=500)

    return JSONResponse({"answer": answer})
|
|
|
|
|
@app.get("/health")
async def health():
    """Report that the service is up, along with the name of the LLM it serves."""
    payload = dict(status="ok", model="SmolLM2-360M")
    return payload
|
|
|
|
|
|
|
|
# Web UI: setup instructions for the Pluely client plus a single-question test
# box wired directly to generate_answer. `demo` is mounted onto the FastAPI
# app at the end of the file.
with gr.Blocks(title="Fast Q&A", theme=gr.themes.Soft()) as demo:
    # Static setup notes. The curl commands target the /api/stt and /api/ai
    # routes, and "Response Path" names the JSON field each route returns.
    # {{AUDIO_BASE64}} / {{TEXT}} are presumably placeholders that Pluely
    # fills in; confirm against the Pluely docs.
    gr.Markdown("""
    # ⚡ Ultra-Fast Q&A System
    **SmolLM2-360M** + **Direct REST API** for Pluely

    ## Pluely Configuration:

    ### STT Endpoint:
    ```
    curl -X POST https://archcoder-basic-app.hf.space/api/stt -H "Content-Type: application/json" -d '{"audio": "{{AUDIO_BASE64}}"}'
    ```
    **Response Path:** `text`

    ### AI Endpoint:
    ```
    curl -X POST https://archcoder-basic-app.hf.space/api/ai -H "Content-Type: application/json" -d '{"text": "{{TEXT}}"}'
    ```
    **Response Path:** `answer`
    """)

    # Manual test tab: question box and button side by side, answer below.
    with gr.Tab("Test"):
        with gr.Row():
            test_input = gr.Textbox(label="Question", placeholder="Ask anything...")
            test_btn = gr.Button("🚀 Test")
        test_output = gr.Textbox(label="Answer", lines=8)

        # Clicking the button sends the question text to generate_answer and
        # shows the returned string in the answer box.
        test_btn.click(fn=generate_answer, inputs=[test_input], outputs=[test_output])
|
|
|
|
|
|
|
|
# Serve the Gradio UI from the same FastAPI app at "/". The API routes were
# registered earlier, so they are matched before this catch-all mount.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|
|