# Captured from a Hugging Face Space page (Space status: Sleeping).
# File size: 4,564 bytes. Commit hashes shown in the page gutter: 6fa17cb, 31fd9d9.
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import time
import fitz # PyMuPDF for PDF support
import io
# Models under comparison: display name -> identifier handed to
# ``from_pretrained`` when the model is loaded.
MODELS = {
    "Pixtral-12B": "mistralai/Pixtral-12B-2409",
    "InternVL-2.5": "OpenGVLab/InternVL2_5-Chat",
    "Aria-7B": "Aria-7B",  # Replace with actual model ID when public
}

# Identifier -> (processor, model) pairs; starts empty and is filled by
# load_model() the first time each identifier is requested.
MODEL_CACHE = {}
# Processors and models are loaded on first use (lazy loading) so that the
# app does not pay the loading cost at startup.
def load_model(model_id):
    """Return the ``(processor, model)`` pair for *model_id*, loading it once.

    The first request for an identifier loads the processor and the model via
    ``from_pretrained`` (model in float16 with ``device_map="auto"``) and
    stores the pair in ``MODEL_CACHE``; later requests return the stored pair.

    Args:
        model_id: Identifier passed unchanged to both ``from_pretrained`` calls.

    Returns:
        The cached ``(processor, model)`` tuple.
    """
    if model_id in MODEL_CACHE:
        return MODEL_CACHE[model_id]

    loaded_processor = AutoProcessor.from_pretrained(model_id)
    loaded_model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    entry = (loaded_processor, loaded_model)
    MODEL_CACHE[model_id] = entry
    return entry
def convert_pdf_to_image(pdf_bytes):
    """Render the first page of a PDF to a PIL image.

    Args:
        pdf_bytes: Raw bytes of the PDF document.

    Returns:
        A fully decoded PIL image of page 0, rasterised at 150 DPI.

    Raises:
        ValueError: If the input is empty or any step of the conversion
            fails. The underlying exception is attached as ``__cause__``.
    """
    try:
        if not pdf_bytes:
            raise ValueError("no PDF data provided")
        # The context manager closes the document even when rendering fails,
        # so a long-running app does not accumulate open PDF handles.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_doc:
            page = pdf_doc.load_page(0)  # first page only
            pix = page.get_pixmap(dpi=150)
            image_bytes = pix.tobytes("png")
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so that a corrupt render is reported
        # here, wrapped in the ValueError below, rather than at first use.
        image.load()
        return image
    except Exception as e:
        raise ValueError(f"Failed to convert PDF: {e}") from e
def _load_input_image(file):
    """Return a PIL image for an uploaded image or PDF (first page only)."""
    if isinstance(file, str):
        # Uploads may arrive as a filesystem path; a PDF path still has to be
        # rasterised instead of being handed to Image.open directly.
        if file.lower().endswith(".pdf"):
            with open(file, "rb") as handle:
                return convert_pdf_to_image(handle.read())
        return Image.open(file)

    file_bytes = file.read() if hasattr(file, "read") else file
    # Raw bytes have no ``name`` attribute, so also recognise PDFs by their
    # magic number instead of relying on the file extension alone.
    file_name = str(getattr(file, "name", "") or "")
    if file_name.lower().endswith(".pdf") or file_bytes[:5] == b"%PDF-":
        return convert_pdf_to_image(file_bytes)
    return Image.open(io.BytesIO(file_bytes))


def compare_models(file, prompt):
    """Run *prompt* against the upload with every model in ``MODELS``.

    Args:
        file: Uploaded image or PDF, given as a path string, a file-like
            object with ``read()``, or raw bytes. For PDFs only the first
            page is used.
        prompt: Question or instruction sent to each model.

    Returns:
        A ``(responses, latency_data)`` tuple. ``responses`` is a list with
        one display string per model, in ``MODELS`` order. ``latency_data``
        maps model name to elapsed seconds (``0`` for a model that failed),
        or is ``None`` when *file* or *prompt* is missing.

    Raises:
        ValueError: If a PDF upload cannot be converted to an image.
    """
    if file is None or not prompt:
        # Same list shape as the success path, so callers that index the
        # responses by position work for this case as well.
        message = "Please provide both image/PDF and prompt."
        return [message for _ in MODELS], None

    image = _load_input_image(file)
    image.thumbnail((512, 512))  # shrink in place to keep preprocessing cheap

    responses = []
    latency_data = {}
    for name, model_id in MODELS.items():
        try:
            processor, model = load_model(model_id)
            start = time.perf_counter()
            # Keyword arguments: processors do not agree on the positional
            # order of text and images. Inputs follow the model's own device.
            inputs = processor(
                text=prompt, images=image, return_tensors="pt"
            ).to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=128)
            text = processor.decode(outputs[0], skip_special_tokens=True)
            elapsed = time.perf_counter() - start
        except Exception as e:
            # One failing model must not abort the comparison of the others;
            # its error is shown in place of its answer.
            responses.append(f"❌ Error: {str(e)}")
            latency_data[name] = 0
        else:
            responses.append(f"🧠 {text}\n\n⏱️ {elapsed:.2f}s")
            latency_data[name] = elapsed

    # Return results and latency chart data
    return responses, latency_data
def plot_latency(latency_data):
    """Build a bar chart of per-model inference latency.

    Args:
        latency_data: Mapping of model name to elapsed seconds, or a falsy
            value (``None`` / empty mapping) when there is nothing to plot.

    Returns:
        A matplotlib ``Figure`` with one bar per model, or ``None`` when
        *latency_data* is falsy.
    """
    if not latency_data:
        return None

    # Imported here so matplotlib is only needed when there is data to plot.
    from matplotlib.figure import Figure

    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates, so a server that plots on each request would
    # accumulate figures that are never closed.
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    ax.bar(list(latency_data.keys()), list(latency_data.values()))
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    fig.tight_layout()
    return fig
def build_ui():
    """Assemble the Gradio Blocks interface and return it without launching.

    The page contains a heading, a row with the file upload and the prompt
    box, a row with one output box per model, a latency plot, a run button
    and two example inputs.

    NOTE(review): the nesting of components inside the ``with`` blocks was
    reconstructed from the statement order (latency plot, button and examples
    placed below the output row) - confirm against the deployed layout.

    Returns:
        The ``gr.Blocks`` application object.
    """

    def process(file, prompt):
        # Click handler: one text per output box followed by the latency plot.
        outputs, latency_data = compare_models(file, prompt)
        plot = plot_latency(latency_data)
        return outputs[0], outputs[1], outputs[2], plot

    intro_markdown = (
        "\n"
        "# 🤖 Multimodal Model Comparator\n"
        "Upload an **image or PDF document** and enter a question.\n"
        "The app compares outputs from **Pixtral-12B**, **InternVL-2.5**, and **Aria-7B** side-by-side.\n"
        "_Licenses: Apache 2.0 / MIT — safe for research and demo use._\n"
    )

    with gr.Blocks(title="Multimodal Model Comparator") as demo:
        gr.Markdown(intro_markdown)

        # Inputs: the uploaded document and the question about it.
        with gr.Row():
            file_input = gr.File(
                label="Upload Image or PDF",
                file_types=[".png", ".jpg", ".jpeg", ".pdf"],
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Ask something about the image or PDF...",
            )

        # Outputs: one text box per compared model, side by side.
        with gr.Row():
            pixtral_out = gr.Textbox(label="Pixtral Output")
            internvl_out = gr.Textbox(label="InternVL Output")
            aria_out = gr.Textbox(label="Aria Output")

        latency_plot = gr.Plot(label="Latency Comparison")

        run_button = gr.Button("Run Comparison")
        run_button.click(
            fn=process,
            inputs=[file_input, prompt_input],
            outputs=[pixtral_out, internvl_out, aria_out, latency_plot],
        )

        gr.Examples(
            examples=[
                ["sample_image.jpg", "What is shown in this picture?"],
                ["chart_example.png", "Describe the trend in this chart."],
            ],
            inputs=[file_input, prompt_input],
        )

    return demo
# Script entry point: build the interface and launch it when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    demo = build_ui()
    demo.launch()