Sunil Sarolkar
updated image references
0f38fdf
raw
history blame
4.85 kB
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image
import requests
import time
import io
import fitz # PyMuPDF for PDF support
import matplotlib.pyplot as plt
# Display name -> identifier handed to `from_pretrained` for each compared model.
# Order matters: compare_models returns its texts in this order and build_ui
# assigns them to the output boxes by position.
MODELS = dict(
    [
        ("Pixtral-12B", "mistralai/Pixtral-12B-2409"),
        # NOTE(review): the repo name suggests a ~241B-parameter MoE checkpoint;
        # confirm the deployment hardware can actually load it.
        ("InternVL-3.5", "OpenGVLab/InternVL3_5-241B-A28B"),
        # Placeholder ID: replace with the actual model ID when public.
        ("Aria-7B", "Aria-7B"),
    ]
)

# model_id -> (processor, model); populated lazily by load_model().
MODEL_CACHE = dict()
def load_model(model_id):
    """Return the ``(processor, model)`` pair for *model_id*, loading it on first use.

    Pairs are memoised in ``MODEL_CACHE``; later calls for the same id return
    the very same tuple. If loading raises, the exception propagates and
    nothing is cached.
    """
    cached = MODEL_CACHE.get(model_id)
    if cached is not None:
        return cached

    # trust_remote_code allows repositories that ship their own modelling code;
    # weights are loaded in float16 and placed automatically via device_map.
    # NOTE(review): depending on the transformers version, AutoModel may resolve
    # to a backbone without a generation head for some vision-language models —
    # confirm each checkpoint exposes generate()/chat().
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    pair = (processor, model)
    MODEL_CACHE[model_id] = pair
    return pair
def convert_pdf_to_image(pdf_bytes, page_number=0, dpi=150):
    """Render one page of an in-memory PDF to a PIL image.

    Args:
        pdf_bytes: Raw PDF file contents.
        page_number: Zero-based index of the page to render (default: the
            first page). Negative values count from the end, as in Python
            indexing.
        dpi: Rendering resolution passed to PyMuPDF (default 150).

    Returns:
        A PIL image decoded from a PNG rendering of the requested page.

    Raises:
        ValueError: if the PDF has no pages or *page_number* is out of range.
        Errors raised by PyMuPDF for unreadable PDF data propagate unchanged.
    """
    # The context manager closes the document (releasing MuPDF's native
    # resources) even when rendering fails part-way.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_doc:
        page_count = pdf_doc.page_count
        if not -page_count <= page_number < page_count:
            raise ValueError(
                f"PDF page {page_number} is out of range "
                f"(document has {page_count} pages)"
            )
        index = page_number + page_count if page_number < 0 else page_number
        page = pdf_doc.load_page(index)
        pix = page.get_pixmap(dpi=dpi)
        # Encode while the document is still open; the PNG bytes are
        # independent of it afterwards.
        image_bytes = pix.tobytes("png")
    return Image.open(io.BytesIO(image_bytes))
def load_image_from_url(url, timeout=30):
    """Download *url* and open the response body as a PIL image.

    Args:
        url: HTTP(S) URL of an image.
        timeout: Seconds ``requests`` waits for the connection and for each
            read before giving up. Without a timeout a stalled server would
            block the request handler indefinitely.

    Returns:
        The PIL image. Pillow decodes it lazily, on first pixel access.

    Raises:
        ValueError: if the server does not answer with HTTP 200.
        requests.RequestException: on network failure or timeout.
        PIL.UnidentifiedImageError: if the body is not a recognised image.
    """
    # A descriptive User-Agent: some hosts (e.g. Wikimedia, per its
    # User-Agent policy) may refuse generic library user agents.
    headers = {"User-Agent": "multimodal-model-comparator/1.0 (Gradio demo)"}
    response = requests.get(url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(
            f"Failed to load image from {url} (HTTP {response.status_code})"
        )
    return Image.open(io.BytesIO(response.content))
def _load_input_image(input_url, timeout=30):
    """Fetch *input_url* and return it as a PIL image (first page for PDFs).

    URLs whose path ends in ``.pdf`` (case-insensitive, ignoring any query
    string or fragment) are rendered with ``convert_pdf_to_image``; all other
    URLs go through ``load_image_from_url``.

    Raises:
        ValueError: if the PDF download does not return HTTP 200.
        Errors from the download/decoding helpers propagate unchanged.
    """
    from urllib.parse import urlparse

    if urlparse(input_url).path.lower().endswith(".pdf"):
        response = requests.get(input_url, timeout=timeout)
        if response.status_code != 200:
            raise ValueError(
                f"Failed to load PDF from {input_url} (HTTP {response.status_code})"
            )
        return convert_pdf_to_image(response.content)
    return load_image_from_url(input_url)


def _run_model(model_id, image, prompt, max_new_tokens=128):
    """Run one model on *image* + *prompt*; return ``(text, elapsed_seconds)``.

    Model loading (memoised by ``load_model``) is excluded from the timing;
    only inference is measured, using a monotonic high-resolution clock.
    """
    processor, model = load_model(model_id)
    start = time.perf_counter()
    if hasattr(model, "chat"):
        # NOTE(review): remote-code chat() signatures differ per model family
        # (InternVL's appears to expect preprocessed pixel values and a
        # generation config) — verify this call against the model card.
        text = model.chat(processor.tokenizer, image=image, query=prompt)
    else:
        # Keyword arguments: processors disagree on whether text or images is
        # the first positional parameter.
        # NOTE(review): some processors (e.g. Llava/Pixtral-style) also expect an
        # image placeholder token or chat template in the prompt — confirm.
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        # With device_map="auto" the model decides where its input layer lives;
        # send the inputs there instead of assuming "cuda".
        device = getattr(model, "device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)[0]
        # Decoder-only generate() returns prompt + continuation; drop the
        # echoed prompt so only the answer is displayed.
        config = getattr(model, "config", None)
        if "input_ids" in inputs and not getattr(config, "is_encoder_decoder", False):
            output_ids = output_ids[inputs["input_ids"].shape[-1]:]
        text = processor.decode(output_ids, skip_special_tokens=True)
    return text, time.perf_counter() - start


def compare_models(input_url, prompt):
    """Run every model in ``MODELS`` on the same image/PDF URL and prompt.

    Args:
        input_url: HTTP(S) URL of an image or a PDF (its first page is used).
        prompt: Question/instruction sent unchanged to every model.

    Returns:
        ``(texts, latency_data)``. *texts* is a list with one display string
        per entry of ``MODELS``, in the same order, on every code path.
        *latency_data* maps model name -> inference seconds for the models
        that succeeded. It is ``None`` when the input was missing or could
        not be loaded.
    """
    if not input_url or not prompt:
        # Same list shape as the normal path, so callers can index it uniformly.
        return ["Please provide both image/PDF URL and prompt." for _ in MODELS], None

    try:
        image = _load_input_image(input_url)
        # Pillow decodes lazily, so corrupt data may only surface here.
        image.thumbnail((512, 512))
    except Exception as e:  # UI boundary: show the failure instead of crashing the handler
        return [f"❌ Error loading input: {e}" for _ in MODELS], None

    results = {}
    latency_data = {}
    for name, model_id in MODELS.items():
        try:
            text, elapsed = _run_model(model_id, image, prompt)
        except Exception as e:  # one failing model must not hide the others' results
            results[name] = f"❌ Error: {str(e)}"
            # Failed models are left out of the latency chart rather than
            # plotted as a misleading 0 s bar.
            continue
        results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
        latency_data[name] = elapsed
    return [results[name] for name in MODELS], latency_data
def plot_latency(latency_data):
    """Build a bar chart of per-model inference latency.

    Args:
        latency_data: Mapping of model display name -> latency in seconds.
            May be ``None`` or empty.

    Returns:
        A ``matplotlib.figure.Figure`` suitable for ``gr.Plot``, or ``None``
        when *latency_data* is empty or ``None``.
    """
    if not latency_data:
        return None
    # Build the figure without pyplot. pyplot registers every figure in global
    # state (one leaked figure per request unless closed) and is not
    # thread-safe, while Gradio may run handlers concurrently.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    ax.bar(list(latency_data.keys()), list(latency_data.values()))
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    fig.tight_layout()
    return fig
def build_ui():
    """Build (but do not launch) the Gradio Blocks app for model comparison.

    The layout contains:
    - URL and prompt textboxes;
    - one output textbox per compared model (Pixtral, InternVL, Aria);
    - a latency bar chart;
    - a "Run Comparison" button;
    - clickable examples that fill in the two inputs.
    """
    with gr.Blocks(title="Multimodal Model Comparator (Online Images)") as demo:
        # NOTE(review): the license line is not verified against the model
        # cards — confirm before publishing.
        gr.Markdown("""
# 🌐 Multimodal Model Comparator (Online Images)
Enter a **URL** for an image or PDF (must be accessible via HTTPS) and provide a question.
The app compares outputs from **Pixtral-12B**, **InternVL-3.5**, and **Aria-7B** side-by-side.
_Licenses: Apache 2.0 / MIT — safe for research and demo use._
""")
        with gr.Row():
            url_input = gr.Textbox(label="Image or PDF URL", placeholder="https://example.com/sample.jpg")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Ask something about the image or PDF...")
        with gr.Row():
            pixtral_out = gr.Textbox(label="Pixtral Output")
            internvl_out = gr.Textbox(label="InternVL Output")
            aria_out = gr.Textbox(label="Aria Output")
        latency_plot = gr.Plot(label="Latency Comparison")

        def process(input_url, prompt):
            """Click handler: compare all models, then chart their latencies."""
            outputs, latency_data = compare_models(input_url, prompt)
            plot = plot_latency(latency_data)
            # compare_models orders its texts like MODELS: Pixtral, InternVL, Aria.
            return outputs[0], outputs[1], outputs[2], plot

        run_button = gr.Button("Run Comparison")
        run_button.click(
            fn=process,
            inputs=[url_input, prompt_input],
            outputs=[pixtral_out, internvl_out, aria_out, latency_plot],
        )
        gr.Examples(
            examples=[
                ["https://upload.wikimedia.org/wikipedia/commons/9/99/Unofficial_2023_G20_Logo.png", "Describe this image."],
                ["https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg", "What mountain scene is this?"],
                ["https://arxiv.org/pdf/1706.03762.pdf", "What is this paper about?"],
            ],
            inputs=[url_input, prompt_input],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, then start serving it with
    # Gradio's default launch settings.
    demo = build_ui()
    demo.launch()