File size: 4,564 Bytes
6fa17cb
31fd9d9
 
 
 
 
 
6fa17cb
31fd9d9
 
 
 
 
 
6fa17cb
31fd9d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import time
import fitz  # PyMuPDF for PDF support
import io

# Models to compare: display name -> model identifier handed to
# `from_pretrained` by `load_model`. Iteration order of this dict determines
# the order of the results returned by `compare_models`.
MODELS = {
    "Pixtral-12B": "mistralai/Pixtral-12B-2409",
    "InternVL-2.5": "OpenGVLab/InternVL2_5-Chat",
    "Aria-7B": "Aria-7B"  # Placeholder: replace with the actual model ID when public
}

# Lazily filled by `load_model`: model identifier -> (processor, model) tuple.
MODEL_CACHE = {}

# Load models and processors (lazy loading for faster startup)
def load_model(model_id):
    """Return the ``(processor, model)`` pair for *model_id*, loading it on first use.

    The pair is stored in the module-level ``MODEL_CACHE`` so each model is
    instantiated at most once per process.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            processor and the model.

    Returns:
        A ``(processor, model)`` tuple.
    """
    # Fast path: already loaded earlier in this process.
    if model_id in MODEL_CACHE:
        return MODEL_CACHE[model_id]

    loaded_processor = AutoProcessor.from_pretrained(model_id)
    loaded_model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    pair = (loaded_processor, loaded_model)
    MODEL_CACHE[model_id] = pair
    return pair


def convert_pdf_to_image(pdf_bytes):
    """Render the first page of a PDF into a PIL image.

    Args:
        pdf_bytes: Raw bytes of the PDF document.

    Returns:
        A PIL image of page 0, rendered at 150 DPI and encoded through PNG.

    Raises:
        ValueError: If the document cannot be opened or rendered. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        try:
            page = pdf_doc.load_page(0)  # first page only
            pix = page.get_pixmap(dpi=150)
            image_bytes = pix.tobytes("png")
        finally:
            # Release the document even when rendering fails; the PNG bytes
            # are already copied out, so the image does not depend on it.
            pdf_doc.close()
        return Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        # Keep the original traceback attached for debugging.
        raise ValueError(f"Failed to convert PDF: {e}") from e


def _load_input_image(file):
    """Open an uploaded image or PDF (path, file-like object, or bytes) as a PIL image."""
    if isinstance(file, str):
        # A filesystem path: PDFs must be rendered, everything else is
        # handed to PIL directly.
        if not file.lower().endswith(".pdf"):
            return Image.open(file)
        with open(file, "rb") as handle:
            return convert_pdf_to_image(handle.read())

    data = file.read() if hasattr(file, "read") else file
    # Raw bytes carry no name, so also recognise PDFs by their magic number.
    name = str(getattr(file, "name", "") or "")
    if name.lower().endswith(".pdf") or data[:5] == b"%PDF-":
        return convert_pdf_to_image(data)
    return Image.open(io.BytesIO(data))


def compare_models(file, prompt):
    """Run every model in ``MODELS`` on the same image and prompt.

    Args:
        file: The uploaded input: a path string, a file-like object with
            ``read()`` (and optionally ``name``), or raw bytes. PDFs are
            rendered to an image of their first page.
            NOTE(review): which of these the UI actually passes depends on
            the installed Gradio version -- confirm.
        prompt: The question to ask about the image.

    Returns:
        A ``(texts, latency_data)`` tuple. ``texts`` is a list with one
        entry per model, in ``MODELS`` order, holding either the generated
        text plus elapsed time or an error message. ``latency_data`` maps
        model name to elapsed seconds (0 for a model that failed), or is
        ``None`` when the input was incomplete.

    Raises:
        ValueError: If a PDF input cannot be converted.
        Exception: Whatever PIL raises when the image cannot be opened.
    """
    if file is None or not prompt:
        # Return a list (not a dict) so callers can index by position, the
        # same way they do for the normal result.
        message = "Please provide both image/PDF and prompt."
        return [message for _ in MODELS], None

    # Normalise to RGB so palette/alpha images do not reach the processors
    # in a mode they may not handle, then shrink in place to bound the cost.
    image = _load_input_image(file).convert("RGB")
    image.thumbnail((512, 512))

    results = {}
    latency_data = {}

    for name, model_id in MODELS.items():
        try:
            processor, model = load_model(model_id)
            start = time.perf_counter()  # monotonic clock for intervals

            # Keyword arguments: positional order of text/images is assumed
            # to differ between processor classes -- TODO confirm per model.
            # Inputs follow the model's own device instead of a hard-coded
            # guess, since the model is loaded with device_map="auto".
            inputs = processor(
                text=prompt, images=image, return_tensors="pt"
            ).to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=128)
            text = processor.decode(outputs[0], skip_special_tokens=True)

            elapsed = time.perf_counter() - start
            results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
            latency_data[name] = elapsed

        except Exception as e:
            # Best effort: one failing model must not hide the others.
            results[name] = f"❌ Error: {str(e)}"
            latency_data[name] = 0

    # Return results and latency chart data
    return [results.get(name, "Model not loaded.") for name in MODELS], latency_data


def plot_latency(latency_data):
    """Build a bar chart of per-model inference latency.

    Args:
        latency_data: Mapping of model name to elapsed seconds, or a falsy
            value (``None`` / empty) when there is nothing to plot.

    Returns:
        A matplotlib ``Figure`` with one bar per model, or ``None`` when
        *latency_data* is empty or ``None``.
    """
    if not latency_data:
        return None
    # Imported lazily so matplotlib is only needed when a chart is drawn.
    from matplotlib.figure import Figure

    # A standalone Figure is not registered with pyplot's global state, so
    # repeated calls do not accumulate open figures in a long-running app.
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    ax.bar(list(latency_data.keys()), list(latency_data.values()))
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    fig.tight_layout()
    return fig


def build_ui():
    """Assemble the Gradio Blocks interface and return it without launching.

    Layout, top to bottom: an introductory Markdown header, a row with the
    file upload and prompt box, a row with one output textbox per model, the
    latency plot, the run button, and a set of example inputs.

    Returns:
        The constructed ``gr.Blocks`` application.
    """
    with gr.Blocks(title="Multimodal Model Comparator") as app:
        gr.Markdown("""
        # 🤖 Multimodal Model Comparator
        Upload an **image or PDF document** and enter a question. 
        The app compares outputs from **Pixtral-12B**, **InternVL-2.5**, and **Aria-7B** side-by-side.
        
        _Licenses: Apache 2.0 / MIT — safe for research and demo use._
        """)

        # Inputs: the document to inspect and the question about it.
        with gr.Row():
            upload_box = gr.File(
                label="Upload Image or PDF",
                file_types=[".png", ".jpg", ".jpeg", ".pdf"],
            )
            question_box = gr.Textbox(
                label="Prompt",
                placeholder="Ask something about the image or PDF...",
            )

        # Outputs: one textbox per compared model, side by side.
        with gr.Row():
            pixtral_box = gr.Textbox(label="Pixtral Output")
            internvl_box = gr.Textbox(label="InternVL Output")
            aria_box = gr.Textbox(label="Aria Output")

        latency_chart = gr.Plot(label="Latency Comparison")

        def process(file, prompt):
            # Fan the comparison result out to the three textboxes and the plot.
            texts, timings = compare_models(file, prompt)
            chart = plot_latency(timings)
            return texts[0], texts[1], texts[2], chart

        trigger = gr.Button("Run Comparison")
        trigger.click(
            fn=process,
            inputs=[upload_box, question_box],
            outputs=[pixtral_box, internvl_box, aria_box, latency_chart],
        )

        sample_rows = [
            ["sample_image.jpg", "What is shown in this picture?"],
            ["chart_example.png", "Describe the trend in this chart."],
        ]
        gr.Examples(examples=sample_rows, inputs=[upload_box, question_box])

    return app


# Script entry point: build the interface and start the Gradio server only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo = build_ui()
    demo.launch()