# basic_app / app.py
# ArchCoder's picture
# Update app.py
# c2c3825 verified
# raw
# history blame
# 13.4 kB
import gradio as gr
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
import base64
import tempfile
import os
import logging
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from html.parser import HTMLParser
# Configure root logging once at import time: INFO and above, each record
# rendered as "<timestamp> - <message>".
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize models at import time, both on CPU:
#   - faster-whisper "tiny" (int8) for speech-to-text
#   - SmolLM2-360M-Instruct (float32) for answer generation
logger.info("Loading Whisper-tiny...")
whisper_model = WhisperModel("tiny", compute_type="int8", device="cpu")

logger.info("Loading SmolLM2-360M-Instruct...")
model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
logger.info("All models loaded!")
# Optional keys for the paid search engines. An empty string (the default when
# the variable is unset) makes search_tavily / search_brave return None.
TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY', '')
BRAVE_API_KEY = os.environ.get('BRAVE_API_KEY', '')
def search_tavily(query):
    """Search via the Tavily API and format the top two hits.

    Args:
        query: Free-text search query.

    Returns:
        Up to two lines of the form "• <title>: <first 120 chars of content>",
        or None when TAVILY_API_KEY is unset, the request fails, the status is
        not 200, the payload is malformed, or no results come back.
    """
    if not TAVILY_API_KEY:
        return None
    try:
        response = requests.post(
            'https://api.tavily.com/search',
            json={'api_key': TAVILY_API_KEY, 'query': query, 'max_results': 2},
            timeout=1.5,
        )
        if response.status_code != 200:
            logger.info("[SEARCH] ✗ Tavily: HTTP %s", response.status_code)
            return None
        results = response.json().get('results') or []
        # `or ''` guards against explicit JSON nulls, which would otherwise
        # raise on slicing and discard every result.
        lines = [
            f"• {r.get('title') or ''}: {(r.get('content') or '')[:120]}"
            for r in results[:2]
        ]
    except Exception as exc:
        # Best-effort engine: any failure just means "no result from Tavily".
        # Logged (rather than a bare `except: pass`) so outages are visible.
        logger.info("[SEARCH] ✗ Tavily: %r", exc)
        return None
    # Return None rather than "" so "no hits" is explicit to callers.
    return "\n".join(lines) or None
def search_brave(query):
    """Search via the Brave Search API and format the top two web hits.

    Args:
        query: Free-text search query.

    Returns:
        Up to two lines of the form "• <title>: <first 120 chars of description>",
        or None when BRAVE_API_KEY is unset, the request fails, the status is
        not 200, the payload is malformed, or no results come back.
    """
    if not BRAVE_API_KEY:
        return None
    try:
        response = requests.get(
            'https://api.search.brave.com/res/v1/web/search',
            params={'q': query, 'count': 2},
            headers={'X-Subscription-Token': BRAVE_API_KEY},
            timeout=1.5,
        )
        if response.status_code != 200:
            logger.info("[SEARCH] ✗ Brave: HTTP %s", response.status_code)
            return None
        data = response.json()
        results = (data.get('web') or {}).get('results') or []
        # `or ''` guards against explicit JSON nulls in individual fields.
        lines = [
            f"• {r.get('title') or ''}: {(r.get('description') or '')[:120]}"
            for r in results[:2]
        ]
    except Exception as exc:
        # Best-effort engine: any failure just means "no result from Brave".
        logger.info("[SEARCH] ✗ Brave: %r", exc)
        return None
    # Return None rather than "" so "no hits" is explicit to callers.
    return "\n".join(lines) or None
def search_searx(query):
    """Query public SearXNG instances in turn and format the first two hits.

    Each instance is tried in order until one returns at least one result.
    An instance that answers but has no hits does not stop the fallback.

    Args:
        query: Free-text search query.

    Returns:
        Up to two lines of the form "• <title>: <first 120 chars of content>",
        or None if no instance produced any result.
    """
    instances = ('https://searx.be/search', 'https://searx.work/search')
    for instance in instances:
        try:
            response = requests.get(
                instance,
                params={'q': query, 'format': 'json', 'categories': 'general', 'language': 'en'},
                timeout=1.5,
            )
            if response.status_code != 200:
                logger.info("[SEARCH] ✗ Searx %s: HTTP %s", instance, response.status_code)
                continue
            results = response.json().get('results') or []
            lines = [
                f"• {r.get('title') or ''}: {(r.get('content') or '')[:120]}"
                for r in results[:2]
            ]
        except Exception as exc:
            # Best-effort: a broken instance just means "try the next one".
            logger.info("[SEARCH] ✗ Searx %s: %r", instance, exc)
            continue
        if lines:
            return "\n".join(lines)
        # Empty result set: fall through to the next instance instead of
        # returning "" and giving up.
    return None
class _DDGResultParser(HTMLParser):
    """Collect the link text of DuckDuckGo HTML result anchors (class contains "result__a")."""

    def __init__(self):
        super().__init__()
        self.results = []        # collected anchor texts, each capped at 120 chars
        self.in_result = False   # True while inside a matching <a> element
        self.current_text = ""   # text accumulated for the current anchor

    def handle_starttag(self, tag, attrs):
        # Valueless attributes arrive as (name, None); guard before using `in`,
        # otherwise a TypeError would abort parsing of the whole page.
        if tag == 'a' and any(k == 'class' and v and 'result__a' in v for k, v in attrs):
            self.in_result = True

    def handle_data(self, data):
        if self.in_result and data.strip():
            self.current_text += data.strip() + " "

    def handle_endtag(self, tag):
        if tag == 'a' and self.in_result:
            if self.current_text:
                self.results.append(self.current_text.strip()[:120])
            self.current_text = ""
            self.in_result = False


def search_duckduckgo(query):
    """Scrape DuckDuckGo's HTML endpoint (no API key needed) for two hits.

    Args:
        query: Free-text search query.

    Returns:
        Up to two "• <anchor text>" lines (presumably the result titles), or
        None on a non-200 response, any request/parse failure, or no matches.
    """
    try:
        response = requests.get(
            'https://html.duckduckgo.com/html/',
            params={'q': query},
            headers={'User-Agent': 'Mozilla/5.0'},
            timeout=1.5,
        )
        if response.status_code != 200:
            logger.info("[SEARCH] ✗ DuckDuckGo: HTTP %s", response.status_code)
            return None
        parser = _DDGResultParser()
        parser.feed(response.text)
        parser.close()
    except Exception as exc:
        # Best-effort engine: any failure just means "no result from DuckDuckGo".
        logger.info("[SEARCH] ✗ DuckDuckGo: %r", exc)
        return None
    return "\n".join(f"• {r}" for r in parser.results[:2]) or None
def search_parallel(query, timeout=2.0):
    """Run every search engine concurrently and return the best available hit.

    All engines start at once. Results are then taken in priority order
    (Tavily, Brave, Searx, DuckDuckGo): the first engine that returns a
    non-empty result within the overall deadline wins.

    Args:
        query: Free-text search query.
        timeout: Overall budget in seconds, shared by all engines. This is not
            a per-engine timeout. Engines that finished before the deadline
            are still used even if a higher-priority engine timed out.

    Returns:
        ``(results_text, engine_name)``, or
        ``("No search results available.", "None")`` if nothing usable arrived.
    """
    logger.info("[SEARCH] Starting parallel search...")
    engines = (
        ("Tavily", search_tavily),
        ("Brave", search_brave),
        ("Searx", search_searx),
        ("DuckDuckGo", search_duckduckgo),
    )
    executor = ThreadPoolExecutor(max_workers=len(engines))
    try:
        submitted = [(name, executor.submit(fn, query)) for name, fn in engines]
        deadline = time.monotonic() + timeout
        for engine, future in submitted:
            # One shared deadline: waiting on a slow engine eats into the
            # budget of the ones after it instead of adding a fresh timeout.
            remaining = max(0.0, deadline - time.monotonic())
            try:
                result = future.result(timeout=remaining)
            except Exception as exc:
                # Timeout or engine error: fall through to the next engine.
                logger.info("[SEARCH] ✗ %s: %r", engine, exc)
                continue
            if result:
                logger.info(f"[SEARCH] ✓ {engine}")
                return result, engine
    finally:
        # Do not wait for stragglers. A `with` block would call
        # shutdown(wait=True) and block the early return until every engine
        # finished. The remaining workers finish in the background.
        executor.shutdown(wait=False)
    logger.warning("[SEARCH] All engines failed")
    return "No search results available.", "None"
def transcribe_audio_base64(audio_base64):
    """Transcribe a base64-encoded audio clip with the local Whisper model.

    Exposed as the ``transcribe_stt`` API endpoint.

    Args:
        audio_base64: Base64 audio payload (str). A ``data:<mime>;base64,``
            prefix, as in a browser data URL, is stripped before decoding.
            Otherwise its characters would be decoded as audio bytes.

    Returns:
        ``{"text": <transcript>}`` on success, or ``{"error": <message>}`` on
        any failure (empty payload, bad base64, transcription error).
    """
    logger.info("[STT] Processing audio...")
    temp_path = None
    try:
        if not audio_base64:
            raise ValueError("No audio data received")
        if isinstance(audio_base64, str) and audio_base64.startswith("data:"):
            # Raw base64 never contains ':', so only real data URLs match here.
            audio_base64 = audio_base64.split(",", 1)[-1]
        audio_bytes = base64.b64decode(audio_base64)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            # Record the path first so a failed write is still cleaned up.
            temp_path = temp_audio.name
            temp_audio.write(audio_bytes)
        segments, _ = whisper_model.transcribe(temp_path, language="en", beam_size=1)
        # Consume the segments here, before `finally` deletes the file;
        # faster-whisper may yield them lazily.
        transcription = " ".join(seg.text for seg in segments)
        logger.info("[STT] ✓ Transcribed")
        return {"text": transcription.strip()}
    except Exception as e:
        logger.exception(f"[STT] Error: {str(e)}")
        return {"error": str(e)}
    finally:
        # Always remove the temp file, including when transcription fails.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError as exc:
                logger.warning(f"[STT] Could not delete temp file {temp_path}: {exc}")
def generate_answer(text_input):
    """Answer a question with SmolLM2, grounded on a live web search.

    Runs ``search_parallel`` on the question, gives the hits plus the
    question to the chat model, and appends the name of the engine that
    supplied the context.

    Args:
        text_input: The question text as received from Gradio or the
            ``answer_ai`` API endpoint.

    Returns:
        The generated answer followed by "**Source:** <engine>", or a
        human-readable error string. Errors are returned, not raised, because
        this function serves the ``answer_ai`` endpoint directly.
    """
    # Character budgets for the free-text parts of the prompt. The inputs are
    # capped up front instead of relying on tokenizer truncation. That
    # truncation cuts from the right, so on long input it would drop the
    # question and the trailing assistant header.
    max_question_chars = 1000
    max_context_chars = 1000

    logger.info("="*60)
    try:
        logger.info(f"[AI] Raw input: '{text_input}'")
        logger.info(f"[AI] Input type: {type(text_input)}, Length: {len(text_input) if text_input else 0}")
        # Clients such as Pluely may forward the literal template variable when
        # substitution is misconfigured. Treat that the same as empty input.
        if not text_input or text_input.strip() in ["", "{{TEXT}}", "{{text}}", "$TEXT"]:
            error_msg = "❌ ERROR: No question received. Pluely sent empty/template variable.\n\nPluely Config Issue:\n- Check your curl command uses correct format\n- Make sure variable substitution is enabled"
            logger.error(f"[AI] {error_msg}")
            return error_msg
        question = text_input.strip()[:max_question_chars]
        current_date = datetime.now().strftime("%B %d, %Y")

        # Search
        search_start = time.time()
        search_results, search_engine = search_parallel(question)
        search_time = time.time() - search_start
        logger.info(f"[AI] Search completed in {search_time:.2f}s")

        # Generate
        messages = [
            {
                "role": "system",
                "content": f"You are a helpful assistant. Today is {current_date}. Answer questions using the provided search results. Be concise (60-80 words). Use bullet points for multiple items."
            },
            {
                "role": "user",
                "content": f"Search Results:\n{search_results[:max_context_chars]}\n\nQuestion: {question}\n\nAnswer based strictly on search results (60-80 words):"
            },
        ]
        # Render with the tokenizer's own chat template. add_generation_prompt
        # appends the assistant turn header so the model starts a reply.
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        gen_start = time.time()
        inputs = tokenizer(prompt, return_tensors="pt")
        prompt_len = inputs['input_ids'].shape[1]
        logger.info("[AI] Generating answer...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.15,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        gen_time = time.time() - gen_start
        logger.info(f"[AI] Generation completed in {gen_time:.2f}s")
        # Decode only the newly generated tokens, not the echoed prompt.
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        full_answer = f"{answer}\n\n**Source:** {search_engine}"
        logger.info("[AI] ✓ Complete")
        logger.info("="*60)
        return full_answer
    except Exception as e:
        logger.exception(f"[AI] Error: {str(e)}")
        return f"Error: {str(e)}"
def process_audio(audio_path, question_text):
    """Handle one UI request: transcribe audio if given, else use the typed text.

    Args:
        audio_path: Path to an audio file from the Gradio Audio component, or
            a falsy value when the request comes from the text tab.
        question_text: The typed question. Used only when ``audio_path`` is falsy.

    Returns:
        ``(display_text, total_seconds)``. On transcription failure or empty
        input, returns an error message and ``0.0``. Otherwise returns the
        answer with a latency summary (green < 2 s, yellow < 3 s, red above)
        appended.
    """
    t0 = time.time()
    logger.info("="*50)
    logger.info("[MAIN] New request received")

    if not audio_path:
        question = question_text
        logger.info(f"[MAIN] Text input: {question}")
    else:
        logger.info(f"[MAIN] Processing audio: {audio_path}")
        try:
            segments, _ = whisper_model.transcribe(audio_path, language="en", beam_size=1)
            question = " ".join(seg.text for seg in segments)
            logger.info(f"[MAIN] Transcribed: {question}")
        except Exception as err:
            logger.error(f"[MAIN] Transcription failed: {str(err)}")
            return f"❌ Transcription error: {str(err)}", 0.0

    # Guard clause: nothing (or only whitespace) to answer.
    if not (question and question.strip()):
        logger.warning("[MAIN] No input provided")
        return "❌ No input provided", 0.0

    trans_elapsed = time.time() - t0

    answer_started = time.time()
    answer = generate_answer(question)
    answer_elapsed = time.time() - answer_started

    total = time.time() - t0
    if total < 2.0:
        badge = "🟢"
    elif total < 3.0:
        badge = "🟡"
    else:
        badge = "🔴"
    timing = f"\n\n{badge} **Performance:** Trans={trans_elapsed:.2f}s | Search+Gen={answer_elapsed:.2f}s | **Total={total:.2f}s**"

    logger.info(f"[MAIN] Total time: {total:.2f}s")
    logger.info("="*50)
    return answer + timing, total
def audio_handler(audio_path):
    """Gradio callback for the audio tab: answer the spoken question in *audio_path*."""
    return process_audio(audio_path=audio_path, question_text=None)
def text_handler(text_input):
    """Gradio callback for the text tab: answer the typed question *text_input*."""
    return process_audio(audio_path=None, question_text=text_input)
# Gradio Interface
# Three tabs: audio Q&A, text Q&A, and a Pluely tab whose hidden components
# expose transcribe_audio_base64 / generate_answer as the named API endpoints
# "transcribe_stt" and "answer_ai".
with gr.Blocks(title="Ultra-Fast Q&A - SmolLM2-360M", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ⚡ Ultra-Fast Political Q&A System
**SmolLM2-360M** (250-400 tok/s) + **Parallel Search**
""")
    # Audio tab: microphone/upload -> audio_handler -> answer text + seconds.
    with gr.Tab("🎙️ Audio Input"):
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
                audio_submit = gr.Button("🚀 Submit", variant="primary")
            with gr.Column():
                audio_output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)
                audio_time = gr.Number(label="Time (s)", precision=2)
        audio_submit.click(fn=audio_handler, inputs=[audio_input], outputs=[audio_output, audio_time], api_name="audio_query")
    # Text tab: typed question -> text_handler -> answer text + seconds.
    with gr.Tab("✍️ Text Input"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Question", placeholder="Ask anything...", lines=3)
                text_submit = gr.Button("🚀 Submit", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Answer", lines=10, show_copy_button=True)
                text_time = gr.Number(label="Time (s)", precision=2)
        text_submit.click(fn=text_handler, inputs=[text_input], outputs=[text_output, text_time], api_name="text_query")
        gr.Examples(examples=[["Who is the US president?"]], inputs=text_input)
    # API tab: usage notes for the Pluely client plus invisible wiring that
    # registers the programmatic endpoints.
    with gr.Tab("🔌 Pluely API"):
        gr.Markdown("""
## ⚠️ IMPORTANT: Pluely Configuration
### If you see "{{TEXT}}" in logs, try these formats:
**Format 1 (Windows CMD - Use This First):**
```
curl -X POST https://archcoder-basic-app.hf.space/call/answer_ai -H "Content-Type: application/json" -d "{\\"data\\": [\\"TEXT_PLACEHOLDER\\"]}"
```
Then in Pluely, replace `TEXT_PLACEHOLDER` with `{{TEXT}}`
**Format 2 (Alternative):**
```
curl -X POST https://archcoder-basic-app.hf.space/call/answer_ai -H "Content-Type: application/json" --data-binary "{\\"data\\": [\\"{{TEXT}}\\"]}"
```
**Response Path:** `data[0]`
---
### STT Endpoint:
```
curl -X POST https://archcoder-basic-app.hf.space/call/transcribe_stt -H "Content-Type: application/json" -d "{\\"data\\": [\\"{{AUDIO_BASE64}}\\"]}"
```
**Response Path:** `data[0].text`
""")
        # Invisible I/O components that only exist to give the API endpoints
        # typed inputs/outputs; they never appear in the UI.
        with gr.Row(visible=False):
            stt_in = gr.Textbox()
            stt_out = gr.JSON()
            ai_in = gr.Textbox()
            ai_out = gr.Textbox()
        # Hidden buttons: their click events carry the api_name registrations.
        gr.Button("STT", visible=False).click(fn=transcribe_audio_base64, inputs=[stt_in], outputs=[stt_out], api_name="transcribe_stt")
        gr.Button("AI", visible=False).click(fn=generate_answer, inputs=[ai_in], outputs=[ai_out], api_name="answer_ai")
    # Legend for the latency badges produced by process_audio (2 s / 3 s thresholds).
    gr.Markdown("🟢 < 2s | 🟡 2-3s | 🔴 > 3s")
if __name__ == "__main__":
    # Cap the request queue at 5 entries, then start serving the app.
    # Blocks.queue() returns the Blocks instance, so the calls chain.
    demo.queue(max_size=5).launch()