File size: 3,323 Bytes
a7b143d
cb277bc
602cad6
cb277bc
602cad6
 
 
cb277bc
602cad6
 
 
 
cb277bc
602cad6
 
 
 
cb277bc
602cad6
 
 
 
 
cb277bc
602cad6
 
 
 
cb277bc
602cad6
 
 
 
 
 
 
 
 
cb277bc
602cad6
 
cb277bc
602cad6
 
 
 
 
 
 
 
 
 
 
 
 
a7b143d
602cad6
cb277bc
 
 
602cad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7b143d
cb277bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# app.py (fixed, no concurrency_count)
import os, sys, time, traceback, subprocess
from typing import Tuple, Optional
from PIL import Image

# Import Gradio, installing it on the fly when it is missing from the environment.
try:
    import gradio as gr
except ImportError:
    # Run pip through the interpreter executing this script (sys.executable) so
    # the package is installed for that same interpreter; "-q" quiets pip.
    # check_call raises CalledProcessError if pip exits with a non-zero status.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "gradio"])
    # Retry the import now that the install step has finished.
    import gradio as gr

def _make_fallback():
    def _fallback_answer_with_controller(image, question, source="auto", distilled_model="auto"):
        return "Placeholder answer (wire your models in controller.py).", "baseline", 0
    return _fallback_answer_with_controller

# Prefer the real controller. If importing it fails for any reason, log a
# warning and substitute the placeholder so the app can still start.
try:
    from controller import answer_with_controller
except Exception as import_err:
    print(
        f"[WARN] Using fallback controller because import failed: {import_err}",
        flush=True,
    )
    answer_with_controller = _make_fallback()

# Passed as the `title` of the Blocks app and rendered as the heading of the
# page's Markdown header.
TITLE = "VQA — Memory + RL Controller"
# Rendered directly below the heading in the same Markdown header.
DESCRIPTION = "Upload an image, enter a question, and the controller will choose the best decoding strategy."

# Choices of the "Controller Source" radio; the selection is forwarded to the
# controller as `source`.
CONTROLLER_SOURCES = ["auto", "distilled", "ppo", "baseline"]
# Choices of the "Distilled Gate (if used)" radio; the selection is forwarded
# to the controller as `distilled_model`.
DISTILLED_CHOICES = ["auto", "logreg", "mlp32"]

def vqa_demo_fn(image: Optional[Image.Image], question: str, source: str, distilled_model: str) -> Tuple[str, str, float]:
    """Answer a question about an image and report the strategy that was used.

    Args:
        image: The uploaded picture, or ``None`` when nothing was provided.
        question: Free-text question; surrounding whitespace is ignored and
            ``None`` is treated as empty.
        source: Controller source, forwarded to ``answer_with_controller``.
        distilled_model: Distilled-gate choice, forwarded to
            ``answer_with_controller``.

    Returns:
        A ``(answer, strategy, latency_ms)`` triple.

        * Rejected input (no image / blank question): a prompt message,
          ``""`` and ``0.0``.
        * Success: the prediction as a string, ``"<action_id>: <strategy_name>"``
          and the elapsed time in milliseconds rounded to one decimal.
        * Failure: ``"Error: <exception>"``, ``"error"`` and the elapsed time.
          The exception is never propagated to the caller.
    """
    if image is None:
        return "Please upload an image.", "", 0.0
    question = (question or "").strip()
    if not question:
        return "Please enter a question.", "", 0.0

    # Timing covers the RGB conversion as well as the controller call.
    t0 = time.perf_counter()
    try:
        image_rgb = image.convert("RGB")
        pred, strategy_name, action_id = answer_with_controller(
            image_rgb, question, source=source, distilled_model=distilled_model
        )
        latency_ms = (time.perf_counter() - t0) * 1000.0
        # Separate the id from the name so the label reads "0: baseline"
        # rather than the run-together "0baseline".
        return str(pred), f"{action_id}: {strategy_name}", round(latency_ms, 1)
    except Exception as err:
        # UI boundary: surface the failure in the outputs and log the full
        # traceback instead of letting the request crash.
        latency_ms = (time.perf_counter() - t0) * 1000.0
        # format_exc() already returns a single string; no join is needed.
        print("[ERROR] Inference failed:\n" + traceback.format_exc(), flush=True)
        return f"Error: {err}", "error", round(latency_ms, 1)

# Build the UI: a header, then one row with two columns — inputs and the
# Predict button in the first column, read-only outputs in the second.
# Components are created inside the context managers, so their order here is
# the order they are laid out in.
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"### {TITLE}\n{DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            # type="pil" matches vqa_demo_fn, which expects a PIL image (or None).
            img_in = gr.Image(
                type="pil",
                label="Image",
                height=320,
                sources=["upload", "webcam", "clipboard"],  # input sources offered to the user
            )
            q_in = gr.Textbox(label="Question", placeholder="e.g., What colour is the bus?", lines=2, max_lines=4)
            # Both radios default to "auto".
            source_in = gr.Radio(CONTROLLER_SOURCES, value="auto", label="Controller Source")
            dist_in = gr.Radio(DISTILLED_CHOICES, value="auto", label="Distilled Gate (if used)")
            run_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            # Output widgets are non-interactive: they only display results.
            ans_out = gr.Textbox(label="Answer", interactive=False, lines=3, max_lines=6)
            strat_out = gr.Textbox(label="Chosen Strategy", interactive=False)
            lat_out = gr.Number(label="Latency (ms)", precision=1, interactive=False)

    # Wire the button: inputs are passed positionally to vqa_demo_fn in the
    # listed order, and its three return values fill the outputs in order.
    # api_name names the endpoint "predict" (presumably for API clients —
    # confirm against the Gradio version in use).
    run_btn.click(
        vqa_demo_fn,
        inputs=[img_in, q_in, source_in, dist_in],
        outputs=[ans_out, strat_out, lat_out],
        api_name="predict",
    )

if __name__ == "__main__":
    # Listen on $PORT when it is set. An unset, empty, or whitespace-only value
    # falls back to 7860 instead of crashing in int(""); a non-numeric value
    # still raises ValueError so a misconfiguration is not silently ignored.
    port = int(os.getenv("PORT", "").strip() or "7860")
    demo.queue()  # enable queuing; concurrency_count is intentionally not passed
    # 0.0.0.0 binds all interfaces so the server is reachable from outside the host/container.
    demo.launch(server_name="0.0.0.0", server_port=port, share=False, show_error=True)