File size: 4,845 Bytes
6fa17cb
31fd9d9
0f38fdf
31fd9d9
0f38fdf
31fd9d9
 
0f38fdf
 
6fa17cb
0f38fdf
31fd9d9
 
0f38fdf
31fd9d9
 
6fa17cb
31fd9d9
 
 
 
0f38fdf
 
31fd9d9
 
 
 
0f38fdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31fd9d9
0f38fdf
31fd9d9
0f38fdf
31fd9d9
0f38fdf
31fd9d9
 
 
 
 
0f38fdf
 
 
 
 
 
31fd9d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f38fdf
31fd9d9
0f38fdf
 
 
 
31fd9d9
 
 
 
0f38fdf
31fd9d9
 
 
 
 
 
 
 
 
0f38fdf
 
31fd9d9
 
 
 
0f38fdf
31fd9d9
 
 
0f38fdf
 
 
31fd9d9
0f38fdf
31fd9d9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image
import requests
import time
import io
import fitz  # PyMuPDF for PDF support
import matplotlib.pyplot as plt

# Define model repository IDs
# Display name -> Hugging Face repository id handed to load_model().
# Insertion order matters: compare_models() returns its per-model outputs in
# this order, and build_ui() maps them positionally onto the Pixtral /
# InternVL / Aria output boxes.
MODELS = {
    "Pixtral-12B": "mistralai/Pixtral-12B-2409",
    # NOTE(review): the repo id suggests a 241B-parameter (28B active) MoE
    # checkpoint; loading it in float16 needs far more memory than a typical
    # single-GPU host — confirm this is the intended variant.
    "InternVL-3.5": "OpenGVLab/InternVL3_5-241B-A28B",
    # NOTE(review): "Aria-7B" is not in "<org>/<repo>" form, so loading it will
    # presumably fail and the UI will show an error for this model.
    "Aria-7B": "Aria-7B"  # Replace with actual model ID when public
}

# model_id -> (processor, model); populated lazily by load_model().
MODEL_CACHE = {}

def load_model(model_id):
    """Return the ``(processor, model)`` pair for *model_id*, loading it on first use.

    Loaded pairs are memoized in the module-level ``MODEL_CACHE`` keyed by
    repository id; later calls with the same id return the cached tuple.

    Args:
        model_id: Hugging Face repository id (or local path) to load.

    Returns:
        Tuple ``(processor, model)``; the model is loaded in float16 with
        ``device_map="auto"``.
    """
    cached = MODEL_CACHE.get(model_id)
    if cached is not None:
        return cached

    # SECURITY NOTE: trust_remote_code=True executes Python code shipped in the
    # model repository; only point this at repositories you trust.
    proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    net = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    entry = (proc, net)
    MODEL_CACHE[model_id] = entry
    return entry

def convert_pdf_to_image(pdf_bytes, page_number=0, dpi=150):
    """Render a single page of an in-memory PDF to a PIL image.

    Args:
        pdf_bytes: Raw PDF file contents.
        page_number: Zero-based index of the page to render (default: first page).
        dpi: Rendering resolution passed to PyMuPDF (default: 150).

    Returns:
        A PIL ``Image`` decoded from the rendered PNG.

    Raises:
        ValueError: if the document has no page at *page_number* (e.g. an empty PDF).
    """
    # Use the document as a context manager so its native resources are
    # released deterministically; the previous version never closed it.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_doc:
        if not 0 <= page_number < pdf_doc.page_count:
            raise ValueError(
                f"PDF has {pdf_doc.page_count} page(s); cannot render page {page_number}"
            )
        pix = pdf_doc.load_page(page_number).get_pixmap(dpi=dpi)
        image_bytes = pix.tobytes("png")
    # The PNG bytes are self-contained, so the image stays valid after the
    # document is closed.
    return Image.open(io.BytesIO(image_bytes))

def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout, ``requests`` can block forever on an
            unresponsive host, tying up the Gradio worker.

    Returns:
        A PIL ``Image`` backed by the downloaded bytes.

    Raises:
        ValueError: if the server answers with any status other than 200.
        requests.RequestException: on network errors or timeout.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Failed to load image from {url}")
    return Image.open(io.BytesIO(response.content))

def _fetch_input_image(input_url):
    """Download *input_url* and return it as a PIL image (first page only for PDFs)."""
    from urllib.parse import urlparse

    # Inspect only the URL path so a query string or fragment
    # (e.g. ".../paper.pdf?download=1") does not hide the ".pdf" extension.
    if urlparse(input_url).path.lower().endswith(".pdf"):
        # Timeout + status check so a dead or slow host fails fast instead of
        # hanging, and an HTML error page is never handed to the PDF parser.
        response = requests.get(input_url, timeout=30)
        if response.status_code != 200:
            raise ValueError(f"Failed to load PDF from {input_url}")
        return convert_pdf_to_image(response.content)
    return load_image_from_url(input_url)


def _run_model(processor, model, image, prompt):
    """Run one loaded model on (*image*, *prompt*) and return its decoded text answer."""
    if hasattr(model, "chat"):
        # NOTE(review): remote-code ``chat`` signatures differ between model
        # families (InternVL's, for example, takes pre-processed pixel values
        # and a generation config) — verify this call against each model card.
        return model.chat(processor.tokenizer, image=image, query=prompt)
    # Pass text/images by keyword: processors disagree on positional order
    # (many transformers processors take ``images`` first), so a positional
    # prompt could land in the image slot.
    # Move inputs to the model's own device: with device_map="auto" that is
    # where the first weights were placed, which need not be plain "cuda".
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128)
    return processor.decode(outputs[0], skip_special_tokens=True)


def compare_models(input_url, prompt):
    """Run every model in ``MODELS`` on the same image/PDF and prompt.

    Args:
        input_url: HTTP(S) URL of an image, or of a PDF whose first page is used.
        prompt: Question to ask each model about the image.

    Returns:
        ``(outputs, latency_data)`` where ``outputs`` is a list with one display
        string per entry of ``MODELS`` (in ``MODELS`` order), and
        ``latency_data`` maps model display name -> inference seconds for the
        models that succeeded. ``latency_data`` is empty when the input is
        missing or could not be loaded.
    """
    if not input_url or not prompt:
        # Always return a list (like the success path): callers index the
        # outputs positionally, and the previous dict made that raise KeyError.
        return ["Please provide both image/PDF URL and prompt." for _ in MODELS], {}

    try:
        # Many processors expect 3-channel input; PNGs are often RGBA or
        # palette ("P") mode. convert() also forces full decoding here, so a
        # truncated or corrupt file is caught by this handler.
        image = _fetch_input_image(input_url.strip()).convert("RGB")
    except Exception as e:  # UI boundary: report any download/decode failure in every box.
        return [f"❌ Error loading input: {e}" for _ in MODELS], {}

    image.thumbnail((512, 512))
    results = []
    latency_data = {}

    for name, model_id in MODELS.items():
        try:
            processor, model = load_model(model_id)
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            text = _run_model(processor, model, image, prompt)
            elapsed = time.perf_counter() - start
        except Exception as e:  # One broken model must not abort the comparison.
            results.append(f"❌ Error: {e}")
            # Failed models are left out of latency_data rather than being
            # plotted as a misleading 0-second bar.
            continue
        results.append(f"🧠 {text}\n\n⏱️ {elapsed:.2f}s")
        latency_data[name] = elapsed

    return results, latency_data

def plot_latency(latency_data):
    """Build a bar chart of per-model inference latency.

    Args:
        latency_data: Mapping of model display name -> latency in seconds
            (may be empty or ``None``).

    Returns:
        A ``matplotlib.figure.Figure`` suitable for ``gr.Plot``, or ``None``
        when there is nothing to plot.
    """
    if not latency_data:
        return None
    # Build a standalone Figure rather than drawing through pyplot: pyplot keeps
    # every figure it creates alive until plt.close(), so returning ``plt`` leaked
    # one figure per request and shared global state between concurrent requests.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 3), tight_layout=True)
    ax = fig.subplots()
    ax.bar(list(latency_data.keys()), list(latency_data.values()))
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    return fig

def build_ui():
    """Build and return the Gradio Blocks app for side-by-side model comparison.

    Layout: a URL textbox and a prompt textbox; one output textbox each for
    Pixtral, InternVL and Aria; a latency bar chart; a "Run Comparison" button
    wired to ``compare_models``; and clickable examples.

    Returns:
        The unlaunched ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Multimodal Model Comparator (Online Images)") as demo:
        # The license line previously contained mojibake ("β€”") for an em dash.
        gr.Markdown("""
        # 🌐 Multimodal Model Comparator (Online Images)
        Enter a **URL** for an image or PDF (must be accessible via HTTPS) and provide a question. 
        The app compares outputs from **Pixtral-12B**, **InternVL-3.5**, and **Aria-7B** side-by-side.

        _Licenses: Apache 2.0 / MIT — safe for research and demo use._
        """)

        with gr.Row():
            url_input = gr.Textbox(label="Image or PDF URL", placeholder="https://example.com/sample.jpg")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Ask something about the image or PDF...")

        with gr.Row():
            pixtral_out = gr.Textbox(label="Pixtral Output")
            internvl_out = gr.Textbox(label="InternVL Output")
            aria_out = gr.Textbox(label="Aria Output")

        latency_plot = gr.Plot(label="Latency Comparison")

        def process(input_url, prompt):
            """Gradio callback: run all models and spread the results over the output components."""
            outputs, latency_data = compare_models(input_url, prompt)
            plot = plot_latency(latency_data)
            # Positional unpacking assumes compare_models() returns one entry per
            # model in the same order as the three output boxes above.
            return outputs[0], outputs[1], outputs[2], plot

        run_button = gr.Button("Run Comparison")
        run_button.click(fn=process, inputs=[url_input, prompt_input], outputs=[pixtral_out, internvl_out, aria_out, latency_plot])

        gr.Examples(
            examples=[
                ["https://upload.wikimedia.org/wikipedia/commons/9/99/Unofficial_2023_G20_Logo.png", "Describe this image."],
                ["https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg", "What mountain scene is this?"],
                ["https://arxiv.org/pdf/1706.03762.pdf", "What is this paper about?"],
            ],
            inputs=[url_input, prompt_input]
        )

    return demo

if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server.
    # NOTE(review): the app is bound to the module-level name `demo`; Gradio's
    # reload tooling looks for that name by default — confirm before renaming.
    demo = build_ui()
    demo.launch()