| import gradio as gr |
| import os |
| import torch |
| import faiss |
| import pandas as pd |
| import numpy as np |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| import torchvision.transforms as tfm |
| import torchvision.transforms.v2 as v2 |
| import requests |
| from io import BytesIO |
| import urllib.parse |
| from sklearn.decomposition import PCA |
| import base64 |
|
|
| |
| from dinov2 import DINOv2FeatureExtractor |
| from dinov3 import DINOv3FeatureExtractor |
|
|
| |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| HF_USERNAME = "pawlo2013" |
| DEFAULT_DATASET = "Cars196" |
| DEFAULT_VERSION = "3" |
| DEFAULT_SIZE = "b" |
|
|
class GlobalState:
    """Lazily-populated resources shared by every Gradio callback.

    All attributes start empty and are filled by ``load_resources``; the
    API helpers below fall back to these cached objects when possible.
    """

    model = None  # active DINO feature extractor (DINOv2/DINOv3)
    index = None  # loaded FAISS index
    mapping_df = None  # DataFrame mapping FAISS ids -> file paths / image URLs
    transform = None  # torchvision preprocessing pipeline for queries
    current_config = {}  # {"key": ...} identifying the currently loaded config
    current_results_text = ""  # cached markdown rendering of results.txt
    pca_model = None  # cached sklearn PCA used for 3D projections


# Single shared instance used throughout this module.
state = GlobalState()
|
|
| |
| |
| |
def extract_class_name(url):
    """Derive a human-readable class label from an image URL or path.

    The label is the second-to-last path segment (the class folder) with
    underscores turned into spaces. Returns "Unknown" when the URL has no
    folder component, and "N/A" if parsing fails entirely.
    """
    try:
        segments = urllib.parse.unquote(url).split('/')
        if len(segments) < 2:
            return "Unknown"
        return segments[-2].replace('_', ' ')
    except Exception:
        return "N/A"
|
|
def get_transforms(dino_version):
    """Build the image preprocessing pipeline for the given DINO version.

    DINOv2 expects 518x518 inputs; any other version gets 512x512.
    Normalization uses the standard ImageNet statistics.
    """
    side = 518 if dino_version == "2" else 512
    return tfm.Compose([
        v2.RGB(),  # force 3 channels before tensor conversion
        tfm.Resize(size=(side, side), antialias=True),
        tfm.ToTensor(),
        tfm.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
|
|
def construct_image_url(file_path, dataset_name):
    """Map a mapping-CSV file path to its hosted image URL on the Hub.

    Strips a leading ``data/<dataset>/`` (or bare ``data/``) prefix and,
    for StanfordOnlineProducts, re-roots the path under the repo's
    ``Stanford_Online_Products`` folder.
    """
    repo = f"{HF_USERNAME}/{dataset_name}"
    path = file_path.replace("\\", "/")

    dataset_prefix = f"data/{dataset_name}/"
    if path.startswith(dataset_prefix):
        path = path.replace(dataset_prefix, "", 1)
    elif path.startswith("data/"):
        path = path.replace("data/", "", 1)

    # SOP images live under an extra top-level folder inside the dataset repo.
    if dataset_name == "StanfordOnlineProducts" and not path.startswith("Stanford_Online_Products"):
        path = f"Stanford_Online_Products/{path}"

    return f"https://huggingface.co/datasets/{repo}/resolve/main/{path}"
|
|
def construct_repo_id(dataset, version, size, finetune):
    """Return ``(index_repo_id, model_repo_id)`` for one configuration.

    The run name is ``<dataset>_dino<version>[_finetune_]<size>``; the
    index lives in ``<user>/<run>`` and weights in ``<user>/<run>-model``.
    """
    finetune_tag = '_finetune_' if finetune else ''
    run_name = f"{dataset}_dino{version}{finetune_tag}{size}"
    return f"{HF_USERNAME}/{run_name}", f"{HF_USERNAME}/{run_name}-model"
|
|
def load_resources(dataset, dino_version, dino_size, is_finetuned):
    """Download and initialise the FAISS index, mapping CSV and DINO model.

    Results are cached on the global ``state``: if the same configuration
    is already loaded, the function returns immediately without
    re-downloading anything.

    Args:
        dataset: dataset name, e.g. "Cars196".
        dino_version: DINO major version as a string, e.g. "3".
        dino_size: model size code ("s", "b", "l").
        is_finetuned: whether to load fine-tuned weights from the model repo.

    Returns:
        tuple[str, str]: (status message, markdown-formatted results text).
    """
    config_key = f"{dataset}_{dino_version}_{dino_size}_{is_finetuned}"
    if state.current_config.get("key") == config_key:
        return (f"Resources already loaded for {config_key}!", state.current_results_text)

    index_repo, model_repo = construct_repo_id(dataset, dino_version, dino_size, is_finetuned)
    results_display = "No results available."

    try:
        # Benchmark stats are optional; a missing results.txt is not fatal.
        try:
            results_path = hf_hub_download(repo_id=index_repo, filename="results.txt", repo_type="dataset")
            with open(results_path, 'r', encoding='utf-8') as f:
                raw_text = f.read()
            results_display = f"```text\n{raw_text}\n```"
        except Exception:
            results_display = "⚠️ `results.txt` not found."

        index_path = hf_hub_download(repo_id=index_repo, filename="faiss_index.bin", repo_type="dataset")
        csv_path = hf_hub_download(repo_id=index_repo, filename="faiss_index_mapping.csv", repo_type="dataset")

        state.index = faiss.read_index(index_path)
        state.mapping_df = pd.read_csv(csv_path)
        # Precompute a resolvable URL for every indexed image.
        state.mapping_df['image_url'] = state.mapping_df['file_path'].apply(
            lambda x: construct_image_url(x, dataset)
        )

        if dino_version == "3":
            model_name_map = {
                "s": "facebook/dinov3-vits16-pretrain-lvd1689m",
                "b": "facebook/dinov3-vitb16-pretrain-lvd1689m",
                "l": "facebook/dinov3-vitl16-pretrain-lvd1689m"
            }
            state.model = DINOv3FeatureExtractor(model_type=model_name_map[dino_size])

        # NOTE(review): for versions other than "3" no model is constructed here,
        # so state.model may still be None below — confirm the UI never offers
        # other versions (the dropdown currently lists only "3").
        if is_finetuned:
            weights_path = hf_hub_download(repo_id=model_repo, filename="best_model.pth", repo_type="model")
            state.model.load_state_dict(torch.load(weights_path, map_location=DEVICE, weights_only=True))

        state.model.to(DEVICE)
        state.model.eval()
        state.transform = get_transforms(dino_version)
        state.current_config = {"key": config_key}
        state.current_results_text = results_display

        return f"✅ Successfully loaded {dataset}", results_display

    except Exception as e:
        return f"❌ Error: {str(e)}", "Error loading stats."
|
|
| def pil_to_base64(pil_img): |
| """Converts a PIL Image to a base64 data URI string.""" |
| img_buffer = BytesIO() |
| pil_img = pil_img.convert("RGB") |
| pil_img.save(img_buffer, format="JPEG") |
| byte_data = img_buffer.getvalue() |
| base64_str = base64.b64encode(byte_data).decode("utf-8") |
| return f"data:image/jpeg;base64,{base64_str}" |
|
|
def fetch_image_from_url(url):
    """Load an RGB PIL image from an http(s) URL or a base64 data URI.

    Any failure (bad URL, timeout, decode error) yields a solid red
    224x224 placeholder instead of raising, so galleries never break on
    a single bad entry.
    """
    try:
        if url.startswith("data:image"):
            _, payload = url.split(",", 1)
            raw = base64.b64decode(payload)
        else:
            resp = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=5)
            resp.raise_for_status()
            raw = resp.content
        return Image.open(BytesIO(raw)).convert("RGB")
    except Exception:
        return Image.new("RGB", (224, 224), color="red")
|
|
def get_example_images(num_examples=10):
    """Sample random (image, url) pairs for the example gallery.

    Prefers rows from the 'test' split; falls back to the full mapping
    when no test rows exist. Returns [] if nothing is loaded yet.
    """
    if state.mapping_df is None:
        return []
    pool = state.mapping_df[state.mapping_df['split'] == 'test']
    if pool.empty:
        pool = state.mapping_df
    picked = pool.sample(n=min(len(pool), num_examples))
    examples = []
    for _, row in picked.iterrows():
        examples.append((fetch_image_from_url(row['image_url']), row['image_url']))
    return examples
|
|
def process_image(image_input, k_neighbors):
    """Embed the query image and return its k nearest neighbours.

    Args:
        image_input: URL/data-URI string, numpy array, or PIL image.
        k_neighbors: number of neighbours to retrieve (coerced to int).

    Returns:
        tuple: (list of (PIL image, caption) gallery items, status string).
    """
    if state.model is None or state.index is None:
        return [], "⚠️ Please wait for model to load..."

    try:
        k = int(k_neighbors)
        # Normalise the input into a PIL image.
        if isinstance(image_input, str):
            query_img = fetch_image_from_url(image_input)
        else:
            query_img = Image.fromarray(image_input) if isinstance(image_input, np.ndarray) else image_input

        img_tensor = state.transform(query_img).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            embedding = state.model(img_tensor).cpu().numpy().astype(np.float32)

        # Query vector is L2-normalised in place before the search.
        faiss.normalize_L2(embedding)
        distances, indices = state.index.search(embedding, k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads with -1 when fewer than k vectors exist; also guard
            # against ids outside the mapping table.
            if idx < 0 or idx >= len(state.mapping_df):
                continue
            row = state.mapping_df.iloc[idx]
            url = row['image_url']

            class_name = extract_class_name(url)
            caption = f"Class: {class_name}\nSim: {dist:.3f}"

            res_img = fetch_image_from_url(url)
            results.append((res_img, caption))

        return results, f"✅ Found {k} matches."
    except Exception as e:
        return [], f"❌ Search failed: {str(e)}"
|
|
| |
| |
| |
def get_faiss_samples(index_path, dataset_name, num_samples):
    """
    API endpoint function. Takes a FAISS index path, dataset name, and number of samples.
    Returns file path, class name, image URL (standard string), and 3D PCA coordinates.
    """
    try:
        # Resolve the mapping CSV from whatever form of path was supplied:
        # an index .bin file, a directory, or the CSV path itself.
        if index_path.endswith('.bin'):
            csv_path = index_path.replace('.bin', '_mapping.csv')
        elif os.path.isdir(index_path):
            csv_path = os.path.join(index_path, 'faiss_index_mapping.csv')
        else:
            csv_path = index_path

        # Prefer the on-disk CSV; fall back to the in-memory mapping loaded
        # by load_resources() when the file is missing.
        if not os.path.exists(csv_path):
            if state.mapping_df is not None:
                df = state.mapping_df
            else:
                return {"error": f"Mapping file not found at {csv_path} and no active memory state."}
        else:
            df = pd.read_csv(csv_path)

        # Same fallback order for the index itself: in-memory first, then disk.
        if state.index is not None:
            faiss_idx = state.index
        else:
            if not os.path.exists(index_path):
                return {"error": f"FAISS index not found at {index_path} and not in memory."}
            faiss_idx = faiss.read_index(index_path)

        # Probe reconstruct(); some index types need an explicit direct map
        # before individual vectors can be reconstructed.
        try:
            faiss_idx.reconstruct(0)
        except RuntimeError:
            try:
                faiss_idx.make_direct_map()
            except AttributeError:
                # Index type has no direct-map support; reconstruction below
                # will simply skip failing rows.
                pass

        n = int(num_samples)
        sample_df = df.sample(n=min(n, len(df)))

        vectors = []
        valid_indices = []

        # NOTE(review): assumes the DataFrame's positional index lines up with
        # FAISS vector ids — TODO confirm the mapping CSV is written that way.
        for orig_idx, row in sample_df.iterrows():
            try:
                vec = faiss_idx.reconstruct(int(orig_idx))
                vectors.append(vec)
                valid_indices.append(orig_idx)
            except Exception as e:
                continue

        vectors = np.array(vectors)
        # PCA needs at least 3 points for 3 components; otherwise emit zeros.
        if len(vectors) >= 3:
            pca = PCA(n_components=3)
            pca_coords = pca.fit_transform(vectors)
            # Cache the fitted PCA so embed_image_api() can reuse it.
            state.pca_model = pca
        else:
            pca_coords = np.zeros((len(vectors), 3))
            state.pca_model = None

        results = []
        for i, orig_idx in enumerate(valid_indices):
            row = sample_df.loc[orig_idx]
            file_path = str(row.get('file_path', ''))
            class_name = extract_class_name(file_path)

            # Reuse a precomputed URL when present; otherwise build one.
            if 'image_url' in row and pd.notna(row['image_url']):
                img_url = row['image_url']
            else:
                img_url = construct_image_url(file_path, dataset_name)

            clean_path = file_path.replace('\\', '/')

            results.append({
                "file_path": clean_path,
                "class_name": class_name,
                "image_url": img_url,
                "pca_3d": pca_coords[i].tolist()
            })

        return {"samples": results}
    except Exception as e:
        return {"error": str(e)}
|
|
def embed_image_api(image_input, index_path, dataset_name, skip_pca=False):
    """
    API endpoint function. Embeds the Image using the model.
    If skip_pca is False, projects it into 3D using the cached PCA (or calculates it via index fallback).
    Returns it with the raw_vector and Base64 image.

    NOTE(review): ``dataset_name`` is accepted for API symmetry with
    get_faiss_samples() but is not used in this function body.
    """
    if state.model is None:
        return {"error": "Model not loaded. Please trigger 'Re-Load Resources' via UI or API first."}

    try:
        # Accept a URL/data-URI string, a numpy array, or a PIL image.
        if isinstance(image_input, str):
            query_img = fetch_image_from_url(image_input)
        else:
            query_img = Image.fromarray(image_input) if isinstance(image_input, np.ndarray) else image_input

        img_tensor = state.transform(query_img).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            embedding = state.model(img_tensor).cpu().numpy().astype(np.float32)
        # Normalised in place, matching how search queries are treated.
        faiss.normalize_L2(embedding)

        if skip_pca:
            pca_3d = [0.0, 0.0, 0.0]
        else:
            # Lazy fallback: if get_faiss_samples() hasn't cached a PCA yet,
            # fit one from a fixed-seed random sample of index vectors.
            if state.pca_model is None and index_path:
                faiss_idx = None
                if state.index is not None:
                    faiss_idx = state.index
                elif os.path.exists(index_path):
                    faiss_idx = faiss.read_index(index_path)

                if faiss_idx is not None:
                    try:
                        total_vectors = faiss_idx.ntotal
                        sample_size = min(250, total_vectors)
                        # Fixed seed keeps the fallback projection deterministic
                        # across calls.
                        np.random.seed(42)
                        sample_ids = np.random.choice(total_vectors, sample_size, replace=False)

                        fallback_vectors = []
                        for orig_idx in sample_ids:
                            try:
                                vec = faiss_idx.reconstruct(int(orig_idx))
                                fallback_vectors.append(vec)
                            except Exception:
                                continue

                        fallback_vectors = np.array(fallback_vectors)
                        if len(fallback_vectors) >= 3:
                            pca = PCA(n_components=3)
                            pca.fit(fallback_vectors)
                            state.pca_model = pca
                    except Exception:
                        # Best-effort only: a failed PCA fit falls through to
                        # the zero-vector projection below.
                        pass

            if state.pca_model is not None:
                pca_3d = state.pca_model.transform(embedding)[0].tolist()
            else:
                pca_3d = [0.0, 0.0, 0.0]

        b64_img = pil_to_base64(query_img)

        # Single-element list to mirror get_faiss_samples()' response shape.
        results = [{
            "file_path": "uploaded_query_image",
            "class_name": "Query",
            "image_url": b64_img,
            "pca_3d": pca_3d,
            "raw_vector": embedding[0].tolist()
        }]

        return {"samples": results}

    except Exception as e:
        return {"error": str(e)}
|
|
| |
| |
| |
def refresh_examples_wrapper():
    """Zero-argument callback wrapper used by Gradio to repopulate examples."""
    return get_example_images(num_examples=10)
|
|
def on_select_example(evt: gr.SelectData, gallery_data, k):
    """Run a similarity search when an example gallery item is clicked.

    Args:
        evt: Gradio selection event carrying the clicked item's index.
        gallery_data: current gallery contents as (image, url) pairs.
        k: top-k slider value.

    Returns:
        tuple: (gallery results, status message) for the two bound outputs.
    """
    if not gallery_data:
        # This handler is wired to two outputs (gallery + status); returning a
        # bare None here would not match that arity, so return a valid pair.
        return [], "⚠️ No examples loaded yet."
    url = gallery_data[evt.index][1]
    return process_image(url, k)
|
|
| with gr.Blocks(title="DINO Image Retrieval") as demo: |
| gr.Markdown("# π¦ DINOv3 Image Retrieval System") |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| with gr.Group(): |
| gr.Markdown("### βοΈ Configuration") |
| inp_dataset = gr.Dropdown(label="Dataset", choices=["Cars196", "CUB", "StanfordOnlineProducts"], value=DEFAULT_DATASET) |
| with gr.Row(): |
| inp_ver = gr.Dropdown(label="Version", choices=["3"], value=DEFAULT_VERSION) |
| inp_size = gr.Dropdown(label="Size", choices=["s", "b"], value=DEFAULT_SIZE) |
| |
| inp_finetune = gr.Checkbox(label="Finetuned?", value=False) |
| inp_k = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Top-K Matches") |
| |
| btn_load = gr.Button("Re-Load Resources", variant="secondary") |
| out_status = gr.Textbox(label="Status", value="Initializing...", interactive=False) |
| |
| gr.Markdown("### π Performance Stats") |
| out_results = gr.Markdown(value="Stats will appear here.") |
| |
| with gr.Column(scale=2): |
| with gr.Tabs(): |
| with gr.TabItem("Select Example"): |
| btn_refresh_ex = gr.Button("π Refresh Examples") |
| ex_gallery = gr.Gallery(label="Examples", columns=5, height="auto") |
| |
| with gr.TabItem("Upload Image"): |
| inp_img_upload = gr.Image(type="pil", label="Upload Query") |
| btn_search_upload = gr.Button("π Search", variant="primary") |
| |
| gr.Markdown("### Matches") |
| out_gallery = gr.Gallery(label="Results", columns=5, height="auto") |
|
|
| |
| |
| |
| api_index_path = gr.Textbox(visible=False) |
| api_dataset_name = gr.Textbox(visible=False) |
| api_num_samples = gr.Number(visible=False) |
| api_samples_output = gr.JSON(visible=False) |
| api_samples_btn = gr.Button(visible=False) |
| |
| api_samples_btn.click( |
| fn=get_faiss_samples, |
| inputs=[api_index_path, api_dataset_name, api_num_samples], |
| outputs=[api_samples_output], |
| api_name="get_samples" |
| ) |
|
|
| |
| api_embed_img_input = gr.Image(visible=False) |
| api_skip_pca_input = gr.Checkbox(value=False, visible=False) |
| api_embed_output = gr.JSON(visible=False) |
| api_embed_btn = gr.Button(visible=False) |
|
|
| api_embed_btn.click( |
| fn=embed_image_api, |
| inputs=[api_embed_img_input, api_index_path, api_dataset_name, api_skip_pca_input], |
| outputs=[api_embed_output], |
| api_name="embed" |
| ) |
|
|
| |
| btn_load.click(load_resources, [inp_dataset, inp_ver, inp_size, inp_finetune], [out_status, out_results]).then(refresh_examples_wrapper, outputs=[ex_gallery]) |
| btn_search_upload.click(process_image, [inp_img_upload, inp_k], [out_gallery, out_status]) |
| btn_refresh_ex.click(refresh_examples_wrapper, outputs=[ex_gallery]) |
| ex_gallery.select(on_select_example, [ex_gallery, inp_k], [out_gallery, out_status]) |
| |
| demo.load(load_resources, [inp_dataset, inp_ver, inp_size, inp_finetune], [out_status, out_results], queue=False).then(refresh_examples_wrapper, outputs=[ex_gallery]) |
|
|
| if __name__ == "__main__": |
| demo.launch(theme=gr.themes.Soft()) |