# Hugging Face Space app (status at scrape time: Running)
import io
import os

import requests

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
from sklearn.decomposition import PCA

import matplotlib
matplotlib.use("Agg")  # headless backend: must be set before importing pyplot
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# ── Download model from your ONNX repo ───────────────────────────────────────
# Local filename of the exported ConvNeXt-Tiny encoder.
MODEL_PATH = "eupe_convnext-tiny.onnx"
if not os.path.exists(MODEL_PATH):
    # First run on a fresh container: pull the weights from the Hub into ".".
    print("Downloading model...")
    hf_hub_download(
        repo_id="rockerritesh/EUPE-ONNX",
        filename="eupe_convnext-tiny.onnx",
        local_dir=".",
    )
# One CPU inference session, shared by every request handler below.
sess = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
print("Model loaded!")
# ── Preprocessing ────────────────────────────────────────────────────────────
def preprocess(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a normalized (1, 3, 224, 224) float32 tensor.

    Resizes to the encoder's 224x224 input resolution and applies the
    standard ImageNet per-channel mean/std normalization.
    """
    rgb = img.convert("RGB").resize((224, 224))
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (pixels - mean) / std
    # HWC -> CHW, then prepend the batch dimension.
    return normalized.transpose(2, 0, 1)[np.newaxis].astype(np.float32)
def get_features(img: Image.Image) -> np.ndarray:
    """Run the ONNX encoder and return the image's 768-dim embedding."""
    outputs = sess.run(None, {"input": preprocess(img)})
    # First output tensor, first (and only) batch element -> shape (768,).
    return outputs[0][0]
def cosine(a, b):
    """Cosine similarity of two vectors; epsilon guards a zero denominator."""
    denom = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
    return float(np.dot(a, b) / denom)
# ── Tab 1 : Single image analysis ────────────────────────────────────────────
def analyze_image(img: Image.Image):
    """Embed one image and render a two-panel feature-analysis figure.

    Args:
        img: uploaded PIL image, or None when nothing was uploaded.

    Returns:
        (PIL.Image, str): rendered plot and a markdown stats summary,
        or (None, str) with an error message when no image was supplied.
    """
    if img is None:
        return None, "Please upload an image."
    feat = get_features(img)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    fig.suptitle("EUPE ConvNeXt-Tiny β Feature Analysis", fontsize=13, fontweight="bold")

    # Left panel: top-30 dimensions by |activation|, strongest first,
    # colored by their (min-max scaled) signed value.
    top_idx = np.argsort(np.abs(feat))[-30:][::-1]
    colors = cm.RdYlGn((feat[top_idx] - feat[top_idx].min()) /
                       (feat[top_idx].max() - feat[top_idx].min() + 1e-8))
    axes[0].barh(range(30), feat[top_idx], color=colors)
    axes[0].set_yticks(range(30))
    axes[0].set_yticklabels([f"dim {i}" for i in top_idx], fontsize=7)
    axes[0].set_xlabel("Activation value")
    axes[0].set_title("Top 30 Active Dimensions")
    axes[0].axvline(0, color="black", linewidth=0.8)

    # Right panel: the full 768-dim embedding reshaped into a 32x24 heatmap.
    hm = feat[:768].reshape(32, 24)
    im = axes[1].imshow(hm, cmap="coolwarm", aspect="auto")
    axes[1].set_title("Full Embedding Heatmap (768-dim)")
    axes[1].set_xlabel("Dim group"); axes[1].set_ylabel("Dim group")
    plt.colorbar(im, ax=axes[1])
    plt.tight_layout()

    # Render in-memory. The previous fixed "/tmp/analysis.png" path was shared
    # by all concurrent requests, so one user could receive another's plot.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)  # close THIS figure, not whatever happens to be "current"
    buf.seek(0)
    plot = Image.open(buf)
    plot.load()  # force full decode before the buffer goes out of scope

    stats = (
        f"**Embedding dimension:** 768 \n"
        f"**Mean:** {feat.mean():.4f} \n"
        f"**Std:** {feat.std():.4f} \n"
        f"**Min:** {feat.min():.4f} \n"
        f"**Max:** {feat.max():.4f} \n"
        f"**L2 norm:** {np.linalg.norm(feat):.4f}"
    )
    return plot, stats
# ── Tab 2 : Image similarity ─────────────────────────────────────────────────
def compare_images(img1: Image.Image, img2: Image.Image):
    """Embed two images and visualize their cosine similarity.

    Args:
        img1, img2: uploaded PIL images, or None when missing.

    Returns:
        (PIL.Image, str): comparison figure and a markdown verdict,
        or (None, str) with an error message when either input is missing.
    """
    if img1 is None or img2 is None:
        return None, "Please upload both images."
    f1 = get_features(img1)
    f2 = get_features(img2)
    sim = cosine(f1, f2)

    # Bucket the raw score into a human-readable verdict + title color.
    if sim > 0.90:
        label, color = "Very Similar π’", "green"
    elif sim > 0.70:
        label, color = "Similar π‘", "goldenrod"
    elif sim > 0.40:
        label, color = "Somewhat Similar π ", "orange"
    else:
        label, color = "Different π΄", "red"

    fig, axes = plt.subplots(1, 3, figsize=(13, 4))
    fig.suptitle(f"Similarity: {sim:.4f} β {label}", fontsize=13,
                 fontweight="bold", color=color)
    axes[0].imshow(img1.resize((224,224))); axes[0].set_title("Image 1"); axes[0].axis("off")
    axes[1].imshow(img2.resize((224,224))); axes[1].set_title("Image 2"); axes[1].axis("off")
    # Overlay both 768-dim embeddings for a dimension-wise comparison.
    axes[2].plot(f1, alpha=0.7, label="Image 1", color="steelblue", linewidth=0.8)
    axes[2].plot(f2, alpha=0.7, label="Image 2", color="tomato", linewidth=0.8)
    axes[2].set_title("Embedding Comparison")
    axes[2].set_xlabel("Dimension"); axes[2].set_ylabel("Value")
    axes[2].legend()
    plt.tight_layout()

    # Render in-memory. The previous fixed "/tmp/compare.png" path raced
    # between concurrent requests; also close exactly the figure we made.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    plot = Image.open(buf)
    plot.load()

    info = (
        f"**Cosine Similarity:** {sim:.4f} \n"
        f"**Verdict:** {label} \n\n"
        f"*0.0 = completely different, 1.0 = identical*"
    )
    return plot, info
# ── Tab 3 : Multi-image ranking ──────────────────────────────────────────────
def rank_images(query, img1, img2, img3, img4):
    """Rank up to four candidate images by cosine similarity to a query.

    Args:
        query: query PIL image, or None.
        img1..img4: candidate PIL images; None slots are skipped.

    Returns:
        (PIL.Image, str): ranking figure and a markdown leaderboard,
        or (None, str) with an error message when inputs are missing.
    """
    if query is None:
        return None, "Please upload a query image."
    candidates = [(img, f"Image {i+1}") for i, img in
                  enumerate([img1, img2, img3, img4]) if img is not None]
    if not candidates:
        return None, "Please upload at least one candidate image."

    qf = get_features(query)
    sims = [(cosine(qf, get_features(img)), lbl, img) for img, lbl in candidates]
    # Sort on the score alone. Plain tuple sorting would fall through to
    # comparing PIL Image objects on tied scores and raise TypeError.
    sims.sort(key=lambda t: t[0], reverse=True)

    n = len(sims) + 1  # one panel for the query + one per candidate
    fig, axes = plt.subplots(1, n, figsize=(3.5*n, 4))
    fig.suptitle("Zero-Shot Image Retrieval", fontsize=13, fontweight="bold")
    axes[0].imshow(query.resize((200,200)))
    axes[0].set_title("QUERY", color="red", fontweight="bold", fontsize=11)
    axes[0].axis("off")
    for i, (sim, lbl, img) in enumerate(sims):
        # Green / orange / red traffic-light coloring by similarity band.
        c = "green" if sim > 0.7 else "orange" if sim > 0.4 else "red"
        axes[i+1].imshow(img.resize((200,200)))
        axes[i+1].set_title(f"#{i+1} {lbl}\n{sim:.3f}", color=c, fontsize=10)
        axes[i+1].axis("off")
    plt.tight_layout()

    # Render in-memory. The previous fixed "/tmp/ranking.png" path raced
    # between concurrent requests; also close exactly the figure we made.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    plot = Image.open(buf)
    plot.load()

    result = "\n".join([f"**#{i+1}** {lbl} β `{sim:.4f}`"
                        for i, (sim, lbl, _) in enumerate(sims)])
    return plot, result
# ── Gradio UI ────────────────────────────────────────────────────────────────
with gr.Blocks(title="EUPE Vision Encoder", theme=gr.themes.Soft()) as demo:
    # App header / model card blurb.
    gr.Markdown("""
# π EUPE Vision Encoder β ConvNeXt-Tiny
**Efficient Universal Perception Encoder** by Meta AI β a single lightweight
vision backbone for diverse tasks.
- Model: `eupe_convnext-tiny.onnx` (111 MB FP32, CPU)
- Embedding: 768-dimensional feature vector per image
- [Paper](https://arxiv.org/abs/2603.22387) Β· [ONNX Models](https://huggingface.co/rockerritesh/EUPE-ONNX)
""")

    with gr.Tabs():
        # Tab 1: single-image feature analysis.
        with gr.TabItem("πΌοΈ Analyze Image"):
            # NOTE(review): source indentation was lost; assumes the upload
            # widget sits alone in the row with the button below it — confirm.
            with gr.Row():
                inp_img = gr.Image(type="pil", label="Upload Image")
            btn1 = gr.Button("Analyze", variant="primary")
            out_plot1 = gr.Image(label="Feature Analysis")
            out_text1 = gr.Markdown()
            btn1.click(analyze_image, inputs=inp_img,
                       outputs=[out_plot1, out_text1])

        # Tab 2: pairwise similarity.
        with gr.TabItem("π Compare Two Images"):
            with gr.Row():
                img_a = gr.Image(type="pil", label="Image 1")
                img_b = gr.Image(type="pil", label="Image 2")
            btn2 = gr.Button("Compare", variant="primary")
            out_plot2 = gr.Image(label="Comparison")
            out_text2 = gr.Markdown()
            btn2.click(compare_images,
                       inputs=[img_a, img_b],
                       outputs=[out_plot2, out_text2])

        # Tab 3: query-vs-candidates retrieval.
        with gr.TabItem("π Image Retrieval"):
            gr.Markdown("Upload a **query** image and up to 4 **candidates**. "
                        "Ranks candidates by similarity to the query.")
            with gr.Row():
                q_img = gr.Image(type="pil", label="Query Image")
            with gr.Row():
                c1 = gr.Image(type="pil", label="Candidate 1")
                c2 = gr.Image(type="pil", label="Candidate 2")
                c3 = gr.Image(type="pil", label="Candidate 3")
                c4 = gr.Image(type="pil", label="Candidate 4")
            btn3 = gr.Button("Rank", variant="primary")
            out_plot3 = gr.Image(label="Ranking")
            out_text3 = gr.Markdown()
            btn3.click(rank_images,
                       inputs=[q_img, c1, c2, c3, c4],
                       outputs=[out_plot3, out_text3])

demo.launch()