File size: 6,099 Bytes
e5b1065
8340e2c
 
01aa62d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8340e2c
e5b1065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8340e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5b1065
 
 
 
 
8340e2c
 
 
e5b1065
 
8340e2c
 
 
 
 
 
 
 
e5b1065
 
8340e2c
 
 
 
 
 
 
e5b1065
 
8340e2c
e5b1065
8340e2c
 
 
 
 
 
 
 
e5b1065
 
8340e2c
 
e5b1065
8340e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os, time
import gradio as gr
from transformers import pipeline
import torch

def make_pipe(task, model_id, fp16_ok=False):
    """Build a transformers pipeline, placing it on GPU 0 when available.

    Args:
        task: Pipeline task name, e.g. "summarization".
        model_id: Hugging Face Hub model identifier.
        fp16_ok: When True and CUDA is present, load the weights in float16.
            Ignored on CPU, where half precision is typically unsupported.

    Returns:
        A configured ``transformers.pipeline`` instance.
    """
    extra = {}
    if torch.cuda.is_available():
        extra["device"] = 0  # first CUDA device
        if fp16_ok:
            extra["model_kwargs"] = {"torch_dtype": torch.float16}
    # On CPU, pass no device argument at all and keep the default dtype.
    return pipeline(task, model=model_id, **extra)

# NOTE(review): removed the eager construction of five example pipelines
# (asr/zsc/summarizer/ocr/qa via make_pipe). Those module-level objects were
# never referenced -- agent(), do_ocr() and ask_qa() all go through the lazy
# get_*() loaders below -- so every model was downloaded and kept in memory
# twice, defeating the "loaded lazily per tab to reduce RAM" promise in the
# UI. make_pipe() itself is kept for callers that want explicit GPU/fp16
# placement.

# Lightweight default checkpoints. Each one can be overridden at deploy time
# through an environment variable (e.g. Hugging Face Space Secrets).
def _cfg(var, default):
    """Read a model id from the environment, falling back to *default*."""
    return os.getenv(var, default)

ASR_MODEL = _cfg("ASR_MODEL", "openai/whisper-tiny")
ZSC_MODEL = _cfg("ZSC_MODEL", "typeform/distilbert-base-uncased-mnli")
SUM_MODEL = _cfg("SUM_MODEL", "sshleifer/distilbart-cnn-12-6")
OCR_MODEL = _cfg("OCR_MODEL", "microsoft/trocr-small-printed")
QA_MODEL  = _cfg("QA_MODEL",  "distilbert-base-uncased-distilled-squad")

# Lazy singletons: each pipeline stays None until its tab first needs it.
_asr = _zsc = _summ = _ocr = _qa = None

def get_asr():
    """Return the cached speech-recognition pipeline, building it on first use."""
    global _asr
    if _asr is not None:
        return _asr
    _asr = pipeline("automatic-speech-recognition", model=ASR_MODEL)
    return _asr

def get_zsc():
    """Return the cached zero-shot classifier, building it on first use."""
    global _zsc
    if _zsc is not None:
        return _zsc
    _zsc = pipeline("zero-shot-classification", model=ZSC_MODEL)
    return _zsc

def get_summarizer():
    """Return the cached summarization pipeline, building it on first use."""
    global _summ
    if _summ is not None:
        return _summ
    _summ = pipeline("summarization", model=SUM_MODEL)
    return _summ

def get_ocr():
    """Return the cached image-to-text (OCR) pipeline, building it on first use."""
    global _ocr
    if _ocr is not None:
        return _ocr
    _ocr = pipeline("image-to-text", model=OCR_MODEL)
    return _ocr

def get_qa():
    """Return the cached extractive question-answering pipeline, building it on first use."""
    global _qa
    if _qa is not None:
        return _qa
    _qa = pipeline("question-answering", model=QA_MODEL)
    return _qa

# Intent labels offered to the zero-shot classifier when the user does not
# supply a custom comma-separated list in the UI.
DEFAULT_INTENTS = [
    "turn_on_lights", "turn_off_lights",
    "volume_up", "volume_down",
    "start_music", "pause_music",
    "set_timer", "cancel_timer",
    "open_calendar", "create_note",
    "start_recording", "stop_recording",
]

def _say(message):
    """Wrap a fixed status string in a zero-argument callable."""
    return lambda: message

def _note(text=""):
    """Persist-a-note demo action; echoes the first 60 chars of the text."""
    return f"Note saved: '{text[:60]}'"

# Intent name -> demo action. Every entry is a zero-arg callable except
# "create_note", which optionally receives the transcript (see agent()).
TOOLS = {
    "turn_on_lights": _say("Lights → ON"),
    "turn_off_lights": _say("Lights → OFF"),
    "volume_up": _say("Volume → UP"),
    "volume_down": _say("Volume → DOWN"),
    "start_music": _say("Music → PLAY"),
    "pause_music": _say("Music → PAUSE"),
    "set_timer": _say("Timer → 5 min (demo)"),
    "cancel_timer": _say("Timer → CANCELLED"),
    "open_calendar": _say("Calendar → OPENED"),
    "create_note": _note,
    "start_recording": _say("Recording → STARTED"),
    "stop_recording": _say("Recording → STOPPED"),
}

def parse_intents(custom):
    """Split a comma-separated label string; fall back to DEFAULT_INTENTS.

    Blank entries and surrounding whitespace are discarded. An empty or
    whitespace-only input yields the default intent list.
    """
    if custom and custom.strip():
        stripped = (label.strip() for label in custom.split(","))
        return [label for label in stripped if label]
    return DEFAULT_INTENTS

def agent(audio_path, custom_intents, history):
    """Transcribe audio, classify its intent, run the bound tool, log the step.

    Returns a 4-tuple matching the Gradio outputs [topk, chosen, result, chat]:
    top-3 label->score dict, chosen intent, tool result string, updated history.
    On missing audio or empty transcript, the first two outputs are no-op
    gr.update() placeholders and history is returned unchanged.
    """
    if not audio_path:
        return gr.update(), gr.update(), "No audio.", history

    # Load both pipelines up front (order matters for first-use latency).
    asr_pipe = get_asr()
    zsc_pipe = get_zsc()

    transcript = asr_pipe(audio_path)["text"].strip()
    if not transcript:
        return gr.update(), gr.update(), "No speech detected.", history

    candidates = parse_intents(custom_intents)
    ranked = zsc_pipe(transcript, candidate_labels=candidates, multi_label=False)
    labels = ranked["labels"]
    scores = ranked["scores"]
    top3 = {label: float(score) for label, score in list(zip(labels, scores))[:3]}

    chosen = labels[0]
    if chosen == "create_note":
        # The note tool is the only one that consumes the transcript.
        result = TOOLS[chosen](transcript)
    else:
        result = TOOLS.get(chosen, lambda: f"No tool bound: {chosen}")()

    stamp = time.strftime("%H:%M:%S")
    entry = [f"User: {transcript}", f"{stamp} • {chosen} → {result}"]
    return top3, chosen, result, history + [entry]

def do_ocr(image):
    """OCR an image and summarize the extracted text.

    Returns (ocr_text, summary). Both are empty strings when no image is
    given or the OCR output contains no visible characters.
    """
    if image is None:
        return "", ""

    # Load both pipelines up front (order matters for first-use latency).
    ocr_pipe = get_ocr()
    summ_pipe = get_summarizer()

    text = ocr_pipe(image)[0]["generated_text"]
    if not text.strip():
        return "", ""

    # Cap the summarizer input so oversized pages don't blow the model's
    # context window; only the first 3000 characters are summarized.
    summary = summ_pipe(text[:3000], max_length=120, min_length=30, do_sample=False)[0]["summary_text"]
    return text, summary

def ask_qa(context_text, question):
    """Answer *question* from *context_text* via the extractive QA model.

    Returns the model's answer span, or "" when either input is empty.
    """
    if context_text and question:
        prediction = get_qa()({"context": context_text, "question": question})
        return prediction.get("answer", "")
    return ""

with gr.Blocks(title="Multimodal Voice & OCR Agent") as demo:
    gr.Markdown("## 🎀🧾 Multimodal Voice & OCR Agent\nUses **pre-trained models** only. Models are loaded lazily per tab to reduce RAM.")

    with gr.Tabs():
        with gr.Tab("Voice Agent"):
            with gr.Row():
                audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
                intents_box = gr.Textbox(label="Intents (comma-separated)", value=", ".join(DEFAULT_INTENTS))
            run = gr.Button("Run")
            topk = gr.Label(num_top_classes=3, label="Top-k Intents")
            chosen = gr.Textbox(label="Chosen Intent")
            result = gr.Textbox(label="Action Result")
            chat = gr.Chatbot(label="Execution Log")
            state = gr.State([])
            run.click(agent, inputs=[audio, intents_box, state], outputs=[topk, chosen, result, chat], queue=True)

        with gr.Tab("OCR + Summarize + QA"):
            img = gr.Image(type="filepath", label="Upload an image / screenshot / page")
            ocr_btn = gr.Button("Extract text + Summarize")
            ocr_text = gr.Textbox(label="OCR Text", lines=10)
            ocr_sum  = gr.Textbox(label="Summary", lines=6)
            with gr.Row():
                question = gr.Textbox(label="Ask a question about the OCR text")
                qa_btn   = gr.Button("Answer")
                answer   = gr.Textbox(label="Answer")
            ocr_btn.click(do_ocr, inputs=img, outputs=[ocr_text, ocr_sum])
            qa_btn.click(ask_qa, inputs=[ocr_text, question], outputs=answer)