# pawlo2013's picture
# Update app.py
# 2785a44 verified
import gradio as gr
import os
import torch
import faiss
import pandas as pd
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as tfm
import torchvision.transforms.v2 as v2
import requests
from io import BytesIO
import urllib.parse
from sklearn.decomposition import PCA
import base64
# --- Import your model definitions ---
from dinov2 import DINOv2FeatureExtractor
from dinov3 import DINOv3FeatureExtractor
# --- Constants & Configuration ---
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face account that hosts the dataset / index / model repositories.
HF_USERNAME = "pawlo2013"
# Defaults used to pre-populate the Gradio configuration widgets below.
DEFAULT_DATASET = "Cars196"
DEFAULT_VERSION = "3"  # DINO major version ("3" is the only choice offered in the UI)
DEFAULT_SIZE = "b"  # ViT size code: "s" = small, "b" = base
class GlobalState:
    """Process-wide singleton holding the currently loaded retrieval resources."""
    # Feature-extractor network (DINOv2/DINOv3 wrapper); None until load_resources runs.
    model = None
    # FAISS index of gallery embeddings.
    index = None
    # DataFrame mapping FAISS row ids to file paths / image URLs / splits.
    mapping_df = None
    # Torchvision preprocessing pipeline matching the loaded model.
    transform = None
    # Holds {"key": <config string>} for the active configuration (cache key).
    current_config = {}
    # Markdown-formatted contents of the run's results.txt, shown in the stats panel.
    current_results_text = ""
    pca_model = None  # Cache for the PCA transformation


# Single shared instance mutated by the loader / API functions below.
state = GlobalState()
# ==========================================
# 1. HELPER FUNCTIONS
# ==========================================
def extract_class_name(url):
    """Derive a human-readable class name from the second-to-last URL path segment.

    Underscores in the folder name become spaces. Returns "Unknown" when the
    URL has no folder component and "N/A" on any decoding error.
    """
    try:
        segments = urllib.parse.unquote(url).split('/')
        if len(segments) < 2:
            return "Unknown"
        return segments[-2].replace('_', ' ')
    except Exception:
        return "N/A"
def get_transforms(dino_version):
    """Build the preprocessing pipeline; DINOv2 expects 518x518, DINOv3 512x512."""
    side = 518 if dino_version == "2" else 512
    return tfm.Compose([
        v2.RGB(),  # force 3-channel input before resizing/normalising
        tfm.Resize(size=(side, side), antialias=True),
        tfm.ToTensor(),
        # Standard ImageNet normalisation statistics.
        tfm.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
def construct_image_url(file_path, dataset_name):
    """Map a mapping-CSV file path to a resolvable Hugging Face dataset URL."""
    repo = f"{HF_USERNAME}/{dataset_name}"
    path = file_path.replace("\\", "/")
    # Strip the local "data/<dataset>/" (or bare "data/") prefix, if present.
    dataset_prefix = f"data/{dataset_name}/"
    if path.startswith(dataset_prefix):
        path = path[len(dataset_prefix):]
    elif path.startswith("data/"):
        path = path[len("data/"):]
    # The SOP repository stores files under an extra top-level folder.
    if dataset_name == "StanfordOnlineProducts":
        if not path.startswith("Stanford_Online_Products"):
            path = f"Stanford_Online_Products/{path}"
    return f"https://huggingface.co/datasets/{repo}/resolve/main/{path}"
def construct_repo_id(dataset, version, size, finetune):
    """Return (index_repo_id, model_repo_id) for one run configuration.

    Naming convention: "<dataset>_dino<version><size>" for pretrained runs and
    "<dataset>_dino<version>_finetune_<size>" for finetuned ones; the model
    repository adds a "-model" suffix.
    """
    middle = "_finetune_" if finetune else ""
    run_name = f"{dataset}_dino{version}{middle}{size}"
    return f"{HF_USERNAME}/{run_name}", f"{HF_USERNAME}/{run_name}-model"
def load_resources(dataset, dino_version, dino_size, is_finetuned):
    """Download and wire up all resources for one run configuration.

    Fetches the results file, FAISS index, mapping CSV and (optionally
    finetuned) model weights from the Hugging Face Hub, then stores them on
    the module-level ``state``. Returns a ``(status_message, results_markdown)``
    pair for the two UI outputs.
    """
    config_key = f"{dataset}_{dino_version}_{dino_size}_{is_finetuned}"
    # Cheap cache: skip the whole download/initialisation when this exact
    # configuration is already the active one.
    if state.current_config.get("key") == config_key:
        return (f"Resources already loaded for {config_key}!", state.current_results_text)
    index_repo, model_repo = construct_repo_id(dataset, dino_version, dino_size, is_finetuned)
    results_display = "No results available."
    try:
        # results.txt is optional — show a warning instead of failing the load.
        try:
            results_path = hf_hub_download(repo_id=index_repo, filename="results.txt", repo_type="dataset")
            with open(results_path, 'r', encoding='utf-8') as f:
                raw_text = f.read()
            results_display = f"```text\n{raw_text}\n```"
        except Exception:
            results_display = "⚠️ `results.txt` not found."
        # The search index plus its row-id -> file-path mapping.
        index_path = hf_hub_download(repo_id=index_repo, filename="faiss_index.bin", repo_type="dataset")
        csv_path = hf_hub_download(repo_id=index_repo, filename="faiss_index_mapping.csv", repo_type="dataset")
        state.index = faiss.read_index(index_path)
        state.mapping_df = pd.read_csv(csv_path)
        # Pre-compute a resolvable URL for every indexed image.
        state.mapping_df['image_url'] = state.mapping_df['file_path'].apply(
            lambda x: construct_image_url(x, dataset)
        )
        # NOTE(review): only version "3" builds a model here; selecting "2"
        # would leave state.model unchanged (the UI only offers "3") — confirm intended.
        if dino_version == "3":
            model_name_map = {
                "s": "facebook/dinov3-vits16-pretrain-lvd1689m",
                "b": "facebook/dinov3-vitb16-pretrain-lvd1689m",
                "l": "facebook/dinov3-vitl16-pretrain-lvd1689m"
            }
            state.model = DINOv3FeatureExtractor(model_type=model_name_map[dino_size])
        if is_finetuned:
            # Overlay the finetuned weights on top of the pretrained backbone.
            weights_path = hf_hub_download(repo_id=model_repo, filename="best_model.pth", repo_type="model")
            state.model.load_state_dict(torch.load(weights_path, map_location=DEVICE, weights_only=True))
        state.model.to(DEVICE)
        state.model.eval()
        state.transform = get_transforms(dino_version)
        # Mark this configuration as active only after everything succeeded.
        state.current_config = {"key": config_key}
        state.current_results_text = results_display
        return f"βœ… Successfully loaded {dataset}", results_display
    except Exception as e:
        return f"❌ Error: {str(e)}", "Error loading stats."
def pil_to_base64(pil_img):
    """Converts a PIL Image to a base64 data URI string (JPEG-encoded)."""
    buffer = BytesIO()
    pil_img.convert("RGB").save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"
def fetch_image_from_url(url):
    """Load an RGB image from a data URI or an HTTP(S) URL.

    Any failure (bad base64, network error, non-image payload) yields a solid
    red 224x224 placeholder so downstream gallery code never crashes.
    """
    try:
        if url.startswith("data:image"):
            # Inline base64 payload: everything after the first comma.
            _, payload = url.split(",", 1)
            raw = base64.b64decode(payload)
        else:
            resp = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=5)
            resp.raise_for_status()
            raw = resp.content
        return Image.open(BytesIO(raw)).convert("RGB")
    except Exception:
        # Deliberate best-effort: show an obvious placeholder instead of raising.
        return Image.new("RGB", (224, 224), color="red")
def get_example_images(num_examples=10):
    """Sample up to ``num_examples`` test-split images as (PIL image, url) pairs."""
    if state.mapping_df is None:
        return []
    pool = state.mapping_df[state.mapping_df['split'] == 'test']
    if pool.empty:
        # No explicit test split: fall back to the full mapping table.
        pool = state.mapping_df
    picks = pool.sample(n=min(len(pool), num_examples))
    return [(fetch_image_from_url(r['image_url']), r['image_url']) for _, r in picks.iterrows()]
def process_image(image_input, k_neighbors):
    """Embed the query and return (gallery_items, status_message) for the top-k neighbours.

    ``image_input`` may be a URL/data-URI string, a numpy array, or a PIL image.
    """
    if state.model is None or state.index is None:
        return [], "⚠️ Please wait for model to load..."
    try:
        top_k = int(k_neighbors)
        if isinstance(image_input, str):
            query_img = fetch_image_from_url(image_input)
        elif isinstance(image_input, np.ndarray):
            query_img = Image.fromarray(image_input)
        else:
            query_img = image_input
        batch = state.transform(query_img).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            query_vec = state.model(batch).cpu().numpy().astype(np.float32)
        # Cosine similarity via L2-normalised inner product.
        faiss.normalize_L2(query_vec)
        sims, neighbour_ids = state.index.search(query_vec, top_k)
        gallery = []
        for sim, nid in zip(sims[0], neighbour_ids[0]):
            # Skip padding ids (FAISS returns -1 when fewer than k vectors exist).
            if not (0 <= nid < len(state.mapping_df)):
                continue
            hit = state.mapping_df.iloc[nid]
            label = extract_class_name(hit['image_url'])
            gallery.append((
                fetch_image_from_url(hit['image_url']),
                f"Class: {label}\nSim: {sim:.3f}",
            ))
        return gallery, f"βœ… Found {top_k} matches."
    except Exception as e:
        return [], f"❌ Search failed: {str(e)}"
# ==========================================
# 2. HEADLESS API FUNCTIONS
# ==========================================
def get_faiss_samples(index_path, dataset_name, num_samples):
    """
    API endpoint function. Takes a FAISS index path, dataset name, and number of samples.
    Returns file path, class name, image URL (standard string), and 3D PCA coordinates.
    """
    try:
        # Derive the mapping-CSV location from the index path: sibling file of a
        # .bin, conventional name inside a directory, or the path itself.
        if index_path.endswith('.bin'):
            csv_path = index_path.replace('.bin', '_mapping.csv')
        elif os.path.isdir(index_path):
            csv_path = os.path.join(index_path, 'faiss_index_mapping.csv')
        else:
            csv_path = index_path
        if not os.path.exists(csv_path):
            # Fall back to the mapping the UI has already loaded in memory.
            if state.mapping_df is not None:
                df = state.mapping_df
            else:
                return {"error": f"Mapping file not found at {csv_path} and no active memory state."}
        else:
            df = pd.read_csv(csv_path)
        # Prefer the in-memory index; only read from disk when nothing is loaded.
        if state.index is not None:
            faiss_idx = state.index
        else:
            if not os.path.exists(index_path):
                return {"error": f"FAISS index not found at {index_path} and not in memory."}
            faiss_idx = faiss.read_index(index_path)
        # Some index types need a direct map before reconstruct() can fetch
        # stored vectors by id; probe once and build it on demand.
        try:
            faiss_idx.reconstruct(0)
        except RuntimeError:
            try:
                faiss_idx.make_direct_map()
            except AttributeError:
                # Index exposes no direct map; per-row reconstruction below may fail.
                pass
        n = int(num_samples)
        sample_df = df.sample(n=min(n, len(df)))
        vectors = []
        valid_indices = []
        # NOTE(review): assumes DataFrame row labels coincide with FAISS vector
        # ids — confirm the mapping CSV is written in index order.
        for orig_idx, row in sample_df.iterrows():
            try:
                vec = faiss_idx.reconstruct(int(orig_idx))
                vectors.append(vec)
                valid_indices.append(orig_idx)
            except Exception as e:
                # Unreconstructable rows are silently skipped.
                continue
        vectors = np.array(vectors)
        # PCA needs at least 3 points for 3 components; otherwise emit zeros.
        if len(vectors) >= 3:
            pca = PCA(n_components=3)
            pca_coords = pca.fit_transform(vectors)
            state.pca_model = pca  # <-- Cache the fitted PCA model
        else:
            pca_coords = np.zeros((len(vectors), 3))
            state.pca_model = None
        results = []
        for i, orig_idx in enumerate(valid_indices):
            row = sample_df.loc[orig_idx]
            file_path = str(row.get('file_path', ''))
            class_name = extract_class_name(file_path)
            # Reuse the pre-computed URL when present, else rebuild it.
            if 'image_url' in row and pd.notna(row['image_url']):
                img_url = row['image_url']
            else:
                img_url = construct_image_url(file_path, dataset_name)
            clean_path = file_path.replace('\\', '/')
            results.append({
                "file_path": clean_path,
                "class_name": class_name,
                "image_url": img_url,  # <-- Standard URL (No Base64 overhead)
                "pca_3d": pca_coords[i].tolist()
            })
        return {"samples": results}
    except Exception as e:
        return {"error": str(e)}
def embed_image_api(image_input, index_path, dataset_name, skip_pca=False):
    """
    API endpoint function. Embeds the Image using the model.
    If skip_pca is False, projects it into 3D using the cached PCA (or calculates it via index fallback).
    Returns it with the raw_vector and Base64 image.
    """
    if state.model is None:
        return {"error": "Model not loaded. Please trigger 'Re-Load Resources' via UI or API first."}
    try:
        # Accept a URL/data-URI string, a numpy array, or a PIL image.
        if isinstance(image_input, str):
            query_img = fetch_image_from_url(image_input)
        else:
            query_img = Image.fromarray(image_input) if isinstance(image_input, np.ndarray) else image_input
        img_tensor = state.transform(query_img).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            embedding = state.model(img_tensor).cpu().numpy().astype(np.float32)
        # L2-normalise so the vector lives in the same space as the index.
        faiss.normalize_L2(embedding)
        if skip_pca:
            pca_3d = [0.0, 0.0, 0.0]
        else:
            # Ensure PCA is cached; if not, rebuild it dynamically from the FAISS index
            if state.pca_model is None and index_path:
                faiss_idx = None
                if state.index is not None:
                    faiss_idx = state.index
                elif os.path.exists(index_path):
                    faiss_idx = faiss.read_index(index_path)
                if faiss_idx is not None:
                    try:
                        total_vectors = faiss_idx.ntotal
                        # Fit on a bounded random sample to keep the call fast.
                        sample_size = min(250, total_vectors)
                        # Fixed seed keeps the fallback projection reproducible.
                        np.random.seed(42)
                        sample_ids = np.random.choice(total_vectors, sample_size, replace=False)
                        fallback_vectors = []
                        for orig_idx in sample_ids:
                            try:
                                vec = faiss_idx.reconstruct(int(orig_idx))
                                fallback_vectors.append(vec)
                            except Exception:
                                continue
                        fallback_vectors = np.array(fallback_vectors)
                        if len(fallback_vectors) >= 3:
                            pca = PCA(n_components=3)
                            pca.fit(fallback_vectors)
                            state.pca_model = pca
                    except Exception:
                        # Best-effort: fall through to the zero vector below.
                        pass
            # Transform using the PCA model
            if state.pca_model is not None:
                pca_3d = state.pca_model.transform(embedding)[0].tolist()
            else:
                pca_3d = [0.0, 0.0, 0.0]
        b64_img = pil_to_base64(query_img)
        results = [{
            "file_path": "uploaded_query_image",
            "class_name": "Query",
            "image_url": b64_img,
            "pca_3d": pca_3d,
            "raw_vector": embedding[0].tolist()
        }]
        return {"samples": results}
    except Exception as e:
        return {"error": str(e)}
# ==========================================
# 3. UI WRAPPERS & GRADIO UI
# ==========================================
def refresh_examples_wrapper():
    """UI callback: fetch a fresh batch of 10 example images for the gallery."""
    return get_example_images(num_examples=10)
def on_select_example(evt: gr.SelectData, gallery_data, k):
    """UI callback: run a similarity search for the clicked example image.

    Returns the same ``(results_gallery, status_text)`` pair as process_image,
    since this handler is wired to two output components.
    """
    if not gallery_data:
        # Bug fix: this event feeds two outputs, so the empty case must also
        # yield two values — the old bare ``return`` produced a single None.
        return [], "⚠️ No examples loaded yet."
    # Gallery items are (image, caption) pairs; the caption carries the URL.
    url = gallery_data[evt.index][1]
    return process_image(url, k)
# Gradio UI: left column = run configuration + benchmark stats, right column =
# query tabs and the results gallery. Hidden widgets at the bottom expose the
# headless API endpoints (/api/get_samples, /api/embed).
with gr.Blocks(title="DINO Image Retrieval") as demo:
    gr.Markdown("# πŸ¦– DINOv3 Image Retrieval System")
    with gr.Row():
        # Left column: configuration controls, load status and stats panel.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### βš™οΈ Configuration")
                inp_dataset = gr.Dropdown(label="Dataset", choices=["Cars196", "CUB", "StanfordOnlineProducts"], value=DEFAULT_DATASET)
                with gr.Row():
                    inp_ver = gr.Dropdown(label="Version", choices=["3"], value=DEFAULT_VERSION)
                    inp_size = gr.Dropdown(label="Size", choices=["s", "b"], value=DEFAULT_SIZE)
                inp_finetune = gr.Checkbox(label="Finetuned?", value=False)
                inp_k = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Top-K Matches")
                btn_load = gr.Button("Re-Load Resources", variant="secondary")
            out_status = gr.Textbox(label="Status", value="Initializing...", interactive=False)
            gr.Markdown("### πŸ“Š Performance Stats")
            out_results = gr.Markdown(value="Stats will appear here.")
        # Right column: example picker / upload tabs plus the match gallery.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Select Example"):
                    btn_refresh_ex = gr.Button("πŸ”„ Refresh Examples")
                    ex_gallery = gr.Gallery(label="Examples", columns=5, height="auto")
                with gr.TabItem("Upload Image"):
                    inp_img_upload = gr.Image(type="pil", label="Upload Query")
                    btn_search_upload = gr.Button("πŸ” Search", variant="primary")
            gr.Markdown("### Matches")
            out_gallery = gr.Gallery(label="Results", columns=5, height="auto")
    # --- Hidden API Endpoint Routing ---
    # Invisible components give the headless endpoints typed inputs/outputs.
    # 1. /api/get_samples
    api_index_path = gr.Textbox(visible=False)
    api_dataset_name = gr.Textbox(visible=False)
    api_num_samples = gr.Number(visible=False)
    api_samples_output = gr.JSON(visible=False)
    api_samples_btn = gr.Button(visible=False)
    api_samples_btn.click(
        fn=get_faiss_samples,
        inputs=[api_index_path, api_dataset_name, api_num_samples],
        outputs=[api_samples_output],
        api_name="get_samples"
    )
    # 2. /api/embed
    api_embed_img_input = gr.Image(visible=False)
    api_skip_pca_input = gr.Checkbox(value=False, visible=False)  # <-- New hidden input
    api_embed_output = gr.JSON(visible=False)
    api_embed_btn = gr.Button(visible=False)
    api_embed_btn.click(
        fn=embed_image_api,
        inputs=[api_embed_img_input, api_index_path, api_dataset_name, api_skip_pca_input],
        outputs=[api_embed_output],
        api_name="embed"
    )
    # --- Standard UI Events ---
    btn_load.click(load_resources, [inp_dataset, inp_ver, inp_size, inp_finetune], [out_status, out_results]).then(refresh_examples_wrapper, outputs=[ex_gallery])
    btn_search_upload.click(process_image, [inp_img_upload, inp_k], [out_gallery, out_status])
    btn_refresh_ex.click(refresh_examples_wrapper, outputs=[ex_gallery])
    ex_gallery.select(on_select_example, [ex_gallery, inp_k], [out_gallery, out_status])
    # Auto-load the default configuration on page load, then fill the examples.
    demo.load(load_resources, [inp_dataset, inp_ver, inp_size, inp_finetune], [out_status, out_results], queue=False).then(refresh_examples_wrapper, outputs=[ex_gallery])
if __name__ == "__main__":
    # Bug fix: ``Blocks.launch()`` does not accept a ``theme`` keyword — themes
    # must be passed to the ``gr.Blocks(...)`` constructor — so the previous
    # ``demo.launch(theme=gr.themes.Soft())`` raised a TypeError at startup.
    demo.launch()