import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image
import requests
import time
import io
import fitz # PyMuPDF for PDF support
import matplotlib.pyplot as plt
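
# These imports assume the following pip packages (a typical requirements.txt
# for this Space): gradio, torch, transformers, pillow, requests,
# pymupdf (which provides the `fitz` module), and matplotlib.
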
# Define model repository IDs
MODELS = {
    "Pixtral-12B": "mistralai/Pixtral-12B-2409",
    "InternVL-3.5": "OpenGVLab/InternVL3_5-241B-A28B",
    "Aria-7B": "Aria-7B"  # Replace with actual model ID when public
}
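
# Note: these are very large checkpoints (InternVL3_5-241B-A28B in particular);
# loading all of them with device_map="auto" assumes hardware with substantial
# GPU memory.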
MODEL_CACHE = {}

def load_model(model_id):
    """Load and cache the processor and model for a given repository ID."""
    if model_id not in MODEL_CACHE:
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_id, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto"
        )
        MODEL_CACHE[model_id] = (processor, model)
    return MODEL_CACHE[model_id]

def convert_pdf_to_image(pdf_bytes):
    """Render the first page of a PDF (given as raw bytes) to a PIL image."""
    pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    page = pdf_doc.load_page(0)
    pix = page.get_pixmap(dpi=150)
    image_bytes = pix.tobytes("png")
    image = Image.open(io.BytesIO(image_bytes))
    return image

def load_image_from_url(url):
    """Download an image over HTTP(S) and return it as a PIL image."""
    response = requests.get(url)
    if response.status_code != 200:
        raise ValueError(f"Failed to load image from {url}")
    return Image.open(io.BytesIO(response.content))

def compare_models(input_url, prompt):
    """Run the prompt against each model and return (per-model outputs, latency data)."""
    if not input_url or not prompt:
        return ["Please provide both image/PDF URL and prompt."] * len(MODELS), {}
    # Load image or PDF from URL
    if input_url.lower().endswith('.pdf'):
        pdf_data = requests.get(input_url).content
        image = convert_pdf_to_image(pdf_data)
    else:
        image = load_image_from_url(input_url)
    image.thumbnail((512, 512))
    latency_data = {}
    results = {}
    for name, model_id in MODELS.items():
        try:
            processor, model = load_model(model_id)
            start = time.time()
            if hasattr(model, 'chat'):
                # Some trust_remote_code models expose a custom chat() helper
                text = model.chat(processor.tokenizer, image=image, query=prompt)
            else:
                inputs = processor(text=prompt, images=image, return_tensors="pt").to(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
                outputs = model.generate(**inputs, max_new_tokens=128)
                text = processor.decode(outputs[0], skip_special_tokens=True)
            elapsed = time.time() - start
            results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
            latency_data[name] = elapsed
        except Exception as e:
            results[name] = f"❌ Error: {str(e)}"
            latency_data[name] = 0
    return [results.get(name, "Model not loaded.") for name in MODELS], latency_data

def plot_latency(latency_data):
    """Build a bar chart of per-model inference latency for gr.Plot."""
    if not latency_data:
        return None
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.bar(list(latency_data.keys()), list(latency_data.values()))
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    fig.tight_layout()
    return fig

def build_ui():
    with gr.Blocks(title="Multimodal Model Comparator (Online Images)") as demo:
        gr.Markdown("""
# 🔍 Multimodal Model Comparator (Online Images)
Enter a **URL** for an image or PDF (must be accessible via HTTPS) and provide a question.
The app compares outputs from **Pixtral-12B**, **InternVL-3.5**, and **Aria-7B** side by side.
_Licenses: Apache 2.0 / MIT, safe for research and demo use._
""")
        with gr.Row():
            url_input = gr.Textbox(label="Image or PDF URL", placeholder="https://example.com/sample.jpg")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Ask something about the image or PDF...")
        with gr.Row():
            pixtral_out = gr.Textbox(label="Pixtral Output")
            internvl_out = gr.Textbox(label="InternVL Output")
            aria_out = gr.Textbox(label="Aria Output")
        latency_plot = gr.Plot(label="Latency Comparison")

        def process(input_url, prompt):
            outputs, latency_data = compare_models(input_url, prompt)
            plot = plot_latency(latency_data)
            return outputs[0], outputs[1], outputs[2], plot

        run_button = gr.Button("Run Comparison")
        run_button.click(
            fn=process,
            inputs=[url_input, prompt_input],
            outputs=[pixtral_out, internvl_out, aria_out, latency_plot],
        )
        gr.Examples(
            examples=[
                ["https://upload.wikimedia.org/wikipedia/commons/9/99/Unofficial_2023_G20_Logo.png", "Describe this image."],
                ["https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg", "What mountain scene is this?"],
                ["https://arxiv.org/pdf/1706.03762.pdf", "What is this paper about?"],
            ],
            inputs=[url_input, prompt_input],
        )
    return demo
if __name__ == "__main__":
demo = build_ui()
demo.launch() |
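
# To run locally (assuming the dependencies noted near the imports are installed):
#   python app.py
# Gradio serves the UI at http://localhost:7860 by default. For long-running
# model inference, demo.queue().launch() enables request queuing.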