# Voice_OCR_Agent / app.py
# Provenance: hudaakram's Hugging Face Space, commit 01aa62d ("Update app.py").
# (The original first lines were web-page residue, not valid Python.)
import os, time
import gradio as gr
from transformers import pipeline
import torch
def make_pipe(task, model_id, fp16_ok=False):
    """Create a transformers pipeline for *task* backed by *model_id*.

    When CUDA is available the model is placed on GPU 0, optionally in
    float16 (``fp16_ok=True``); on CPU no device/dtype options are passed.
    """
    extra = {}
    if torch.cuda.is_available():
        extra["device"] = 0  # GPU present: pin the pipeline to device 0
        if fp16_ok:
            extra["model_kwargs"] = {"torch_dtype": torch.float16}
    return pipeline(task, model=model_id, **extra)
# NOTE(review): five pipelines (ASR, zero-shot, summarization, OCR, QA) were
# previously instantiated eagerly here with hard-coded model ids.  That
# duplicated the lazy get_*() accessors defined below, ignored the *_MODEL
# environment overrides, and loaded every heavy model at import time —
# defeating the "models are loaded lazily per tab" design.  Nothing in this
# file used those module-level objects (all call sites go through the
# getters), so the eager block was removed.
# Lighter defaults — each id can be overridden via Space Secrets / env vars.
_env = os.environ.get
ASR_MODEL = _env("ASR_MODEL", "openai/whisper-tiny")
ZSC_MODEL = _env("ZSC_MODEL", "typeform/distilbert-base-uncased-mnli")
SUM_MODEL = _env("SUM_MODEL", "sshleifer/distilbart-cnn-12-6")
OCR_MODEL = _env("OCR_MODEL", "microsoft/trocr-small-printed")
QA_MODEL = _env("QA_MODEL", "distilbert-base-uncased-distilled-squad")

# Lazily-built pipeline cache: each slot stays None until its getter runs.
_asr = None
_zsc = None
_summ = None
_ocr = None
_qa = None
def get_asr():
    """Lazily build and cache the speech-recognition pipeline.

    Routes through make_pipe() so the model lands on the GPU (fp16) when
    CUDA is available — the previous direct pipeline() call always ran on
    CPU regardless of hardware.
    """
    global _asr
    if _asr is None:
        _asr = make_pipe("automatic-speech-recognition", ASR_MODEL, fp16_ok=True)
    return _asr
def get_zsc():
    """Lazily build and cache the zero-shot intent classifier.

    Uses make_pipe() for GPU placement when available (fp16 left off,
    matching the module's example usage for this model).
    """
    global _zsc
    if _zsc is None:
        _zsc = make_pipe("zero-shot-classification", ZSC_MODEL)
    return _zsc
def get_summarizer():
    """Lazily build and cache the summarization pipeline (GPU/fp16-aware
    via make_pipe(); previously always CPU)."""
    global _summ
    if _summ is None:
        _summ = make_pipe("summarization", SUM_MODEL, fp16_ok=True)
    return _summ
def get_ocr():
    """Lazily build and cache the image-to-text (OCR) pipeline (GPU/fp16-aware
    via make_pipe(); previously always CPU)."""
    global _ocr
    if _ocr is None:
        _ocr = make_pipe("image-to-text", OCR_MODEL, fp16_ok=True)
    return _ocr
def get_qa():
    """Lazily build and cache the extractive question-answering pipeline.

    Uses make_pipe() for GPU placement when available (fp16 left off,
    matching the module's example usage for this task).
    """
    global _qa
    if _qa is None:
        _qa = make_pipe("question-answering", QA_MODEL)
    return _qa
# Candidate labels offered to the zero-shot classifier when the user does not
# type their own comma-separated list in the UI.
DEFAULT_INTENTS = [
    "turn_on_lights","turn_off_lights","volume_up","volume_down",
    "start_music","pause_music","set_timer","cancel_timer",
    "open_calendar","create_note","start_recording","stop_recording"
]
# Demo "tool" implementations: each intent maps to a callable returning a
# human-readable confirmation string (no real side effects).  Only
# create_note accepts an argument — agent() special-cases it and passes the
# transcript, which is truncated to 60 characters here.
TOOLS = {
    "turn_on_lights": lambda: "Lights β†’ ON",
    "turn_off_lights": lambda: "Lights β†’ OFF",
    "volume_up": lambda: "Volume β†’ UP",
    "volume_down": lambda: "Volume β†’ DOWN",
    "start_music": lambda: "Music β†’ PLAY",
    "pause_music": lambda: "Music β†’ PAUSE",
    "set_timer": lambda: "Timer β†’ 5 min (demo)",
    "cancel_timer": lambda: "Timer β†’ CANCELLED",
    "open_calendar": lambda: "Calendar β†’ OPENED",
    "create_note": lambda text="": f"Note saved: '{text[:60]}'",
    "start_recording": lambda: "Recording β†’ STARTED",
    "stop_recording": lambda: "Recording β†’ STOPPED",
}
def parse_intents(custom):
    """Turn a comma-separated intent string into a clean list of labels.

    Blank/empty input falls back to DEFAULT_INTENTS; surrounding whitespace
    is stripped and empty fragments are dropped.
    """
    if custom and custom.strip():
        return [label.strip() for label in custom.split(",") if label.strip()]
    return DEFAULT_INTENTS
def agent(audio_path, custom_intents, history):
    """Voice-agent step: transcribe, classify the intent, run the bound tool.

    Returns a 4-tuple for the UI: top-3 intent scores (dict for gr.Label),
    the chosen intent, the tool's result string, and the updated chat log.
    """
    if not audio_path:
        return gr.update(), gr.update(), "No audio.", history

    # Load both pipelines up front (matches the original load order).
    transcriber = get_asr()
    classifier = get_zsc()

    transcript = transcriber(audio_path)["text"].strip()
    if not transcript:
        return gr.update(), gr.update(), "No speech detected.", history

    candidates = parse_intents(custom_intents)
    ranking = classifier(transcript, candidate_labels=candidates, multi_label=False)
    labels = ranking["labels"]
    scores = ranking["scores"]
    top3 = {lbl: float(val) for lbl, val in zip(labels[:3], scores[:3])}

    chosen = labels[0]
    if chosen == "create_note":
        # Only tool that consumes the transcript.
        result = TOOLS[chosen](transcript)
    else:
        result = TOOLS.get(chosen, lambda: f"No tool bound: {chosen}")()

    stamp = time.strftime("%H:%M:%S")
    log_entry = [f"User: {transcript}", f"{stamp} β€’ {chosen} β†’ {result}"]
    return top3, chosen, result, history + [log_entry]
def do_ocr(image):
    """OCR an image, then summarize the extracted text.

    Returns ``(ocr_text, summary)``; both are empty strings when there is
    no image or no recognizable text.
    """
    if image is None:
        return "", ""

    # Load both pipelines first (matches the original load order).
    ocr_pipe = get_ocr()
    sum_pipe = get_summarizer()

    extracted = ocr_pipe(image)[0]["generated_text"]
    if not extracted.strip():
        return "", ""

    # Cap the summarizer input to keep latency/VRAM bounded.
    summary = sum_pipe(
        extracted[:3000], max_length=120, min_length=30, do_sample=False
    )[0]["summary_text"]
    return extracted, summary
def ask_qa(context_text, question):
    """Answer *question* against *context_text* with the extractive QA model.

    Returns an empty string when either input is missing/empty.
    """
    if context_text and question:
        output = get_qa()({"context": context_text, "question": question})
        return output.get("answer", "")
    return ""
def _run_agent(audio_path, custom_intents, history):
    """Adapter around agent() that also writes the updated log into gr.State.

    BUG FIX: the original wiring passed `state` as a click input but never
    listed it as an output, so the accumulated history was never persisted
    and the execution log effectively reset on every click.  Echoing the new
    history into both the Chatbot and the State fixes that.
    """
    top3, chosen_intent, result_text, new_history = agent(
        audio_path, custom_intents, history
    )
    return top3, chosen_intent, result_text, new_history, new_history


# ---------------------------------------------------------------------------
# UI: two tabs sharing the lazily loaded pipelines above.  On HF Spaces the
# gradio SDK auto-launches `demo`; presumably no explicit launch() is needed
# here — TODO confirm for local runs.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Multimodal Voice & OCR Agent") as demo:
    gr.Markdown("## 🎀🧾 Multimodal Voice & OCR Agent\nUses **pre-trained models** only. Models are loaded lazily per tab to reduce RAM.")
    with gr.Tabs():
        with gr.Tab("Voice Agent"):
            with gr.Row():
                audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
                intents_box = gr.Textbox(label="Intents (comma-separated)", value=", ".join(DEFAULT_INTENTS))
            run = gr.Button("Run")
            # gr.Label renders the {label: probability} dict from agent().
            topk = gr.Label(num_top_classes=3, label="Top-k Intents")
            chosen = gr.Textbox(label="Chosen Intent")
            result = gr.Textbox(label="Action Result")
            # NOTE(review): agent() emits [user, bot] pairs; newer gradio
            # versions prefer Chatbot(type="messages") — confirm version.
            chat = gr.Chatbot(label="Execution Log")
            state = gr.State([])  # accumulated execution log across clicks
            run.click(
                _run_agent,
                inputs=[audio, intents_box, state],
                outputs=[topk, chosen, result, chat, state],
                queue=True,
            )
        with gr.Tab("OCR + Summarize + QA"):
            img = gr.Image(type="filepath", label="Upload an image / screenshot / page")
            ocr_btn = gr.Button("Extract text + Summarize")
            ocr_text = gr.Textbox(label="OCR Text", lines=10)
            ocr_sum = gr.Textbox(label="Summary", lines=6)
            with gr.Row():
                question = gr.Textbox(label="Ask a question about the OCR text")
                qa_btn = gr.Button("Answer")
            answer = gr.Textbox(label="Answer")
            ocr_btn.click(do_ocr, inputs=img, outputs=[ocr_text, ocr_sum])
            qa_btn.click(ask_qa, inputs=[ocr_text, question], outputs=answer)