import torch
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import os
import yaml
import requests
import json
import random
from PIL import Image, ImageOps
from io import BytesIO
from types import SimpleNamespace
from torchvision import transforms
from huggingface_hub import hf_hub_download

import mobileclip
from mobileclip.modules.common.mobileone import reparameterize_model
from model import MobileCLIPRanker

# --- Constants ---
HF_USER_REPO = "Nightfury16/clipick"
HF_FILENAME = "best_model_2602.pth"
CONFIG_PATH = "config.yml"
JSON_DATA_PATH = "combined_unique.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_config(path=CONFIG_PATH):
    """Load a YAML config file into nested SimpleNamespace objects.

    Falls back to built-in defaults when the file is missing so the
    server can still start without a config.yml on disk.
    """
    if not os.path.exists(path):
        return SimpleNamespace(**{
            "data": SimpleNamespace(img_size=224),
            "model": SimpleNamespace(name="mobileclip2_l14")
        })
    with open(path, "r") as f:
        cfg_dict = yaml.safe_load(f)

    def recursive_namespace(d):
        # Recursively convert nested dicts to SimpleNamespace for attribute access.
        if isinstance(d, dict):
            for k, v in d.items():
                d[k] = recursive_namespace(v)
            return SimpleNamespace(**d)
        return d

    return recursive_namespace(cfg_dict)


# --- Demo data: pre-load URL groups used by the Gradio example pickers ---
groups_data = []
try:
    if os.path.exists(JSON_DATA_PATH):
        with open(JSON_DATA_PATH, "r") as f:
            data = json.load(f)
        for group in data.get("groups", []):
            urls = group.get("images", [])
            if urls:
                # Stored newline-joined, matching the Textbox input format.
                groups_data.append("\n".join(urls))
        print(f"Loaded {len(groups_data)} groups from JSON.")
except Exception as e:
    # Demo data is optional; the ranking API works without it.
    print(f"Error loading JSON data: {e}")

print("--- Loading Ranker Server ---")
print(f"Device: {DEVICE}")

cfg = load_config(CONFIG_PATH)
model = MobileCLIPRanker(cfg)

try:
    print(f"Downloading Fine-Tuned weights ({HF_FILENAME}) from {HF_USER_REPO}...")
    local_weight_path = hf_hub_download(repo_id=HF_USER_REPO, filename=HF_FILENAME)
    checkpoint = torch.load(local_weight_path, map_location=DEVICE)
    # The file may be either a raw state dict or a training checkpoint
    # that wraps it under "model_state_dict".
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        raw_state_dict = checkpoint["model_state_dict"]
    else:
        raw_state_dict = checkpoint
    # Strip DataParallel's "module." prefix, if present.
    state_dict = {k.replace("module.", ""): v for k, v in raw_state_dict.items()}
    model.load_state_dict(state_dict, strict=True)
    print("✅ Weights loaded successfully!")
except Exception as e:
    # The server is useless without weights — log and re-raise.
    print(f"❌ CRITICAL: Load failed. {e}")
    raise e

print("⚡ Reparameterizing MobileCLIP-B for inference speed...")
if hasattr(model, 'backbone'):
    model.backbone = reparameterize_model(model.backbone)

model.to(DEVICE)
model.eval()

# Normalization only; geometric resizing is handled by letterbox_image().
norm_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.481, 0.457, 0.408), std=(0.268, 0.261, 0.275))
])


def letterbox_image(img, size):
    """Pad image to a size x size square to preserve aspect ratio (no distortion).

    Operates on a copy: PIL's thumbnail() resizes in place, and the original
    code mutated the caller's image as a side effect.
    """
    img = img.copy()  # fix: do not mutate the caller's image
    img.thumbnail((size, size), Image.Resampling.BICUBIC)
    delta_w = size - img.size[0]
    delta_h = size - img.size[1]
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    return ImageOps.expand(img, padding, fill=(128, 128, 128))


def get_best_image(url_list):
    """Score a list of image URLs/paths and rank them with the model.

    Args:
        url_list: iterable of HTTP(S) URLs or local file paths; blanks and
            non-strings are skipped, images that fail to load are skipped.

    Returns:
        (best_url, results) where results is a list of {"url", "score"}
        dicts sorted by descending score, or (None, []) if nothing loaded.
    """
    valid_tensors = []
    valid_indices = []
    clean_urls = [u.strip() for u in url_list if isinstance(u, str) and u.strip()]
    print(f"Processing {len(clean_urls)} images...")

    for i, src in enumerate(clean_urls):
        try:
            if src.startswith("http"):
                resp = requests.get(src, timeout=3)
                # fix: fail fast on 4xx/5xx instead of feeding an error page to PIL
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content)).convert("RGB")
            else:
                img = Image.open(src).convert("RGB")
            img_padded = letterbox_image(img, cfg.data.img_size)
            tensor = norm_transform(img_padded)
            valid_tensors.append(tensor)
            valid_indices.append(i)
        except Exception as e:
            # Best-effort: one bad image should not sink the whole batch.
            print(f"Error loading {src}: {e}")

    if not valid_tensors:
        return None, []

    # The ranker expects one group per batch: (1, num_images, C, H, W).
    batch = torch.stack(valid_tensors).unsqueeze(0).to(DEVICE)
    valid_len = torch.tensor([len(valid_tensors)]).to(DEVICE)
    with torch.no_grad():
        scores = model(batch, valid_lens=valid_len).view(-1).cpu().numpy()

    results = [{"url": clean_urls[idx], "score": float(score)}
               for idx, score in zip(valid_indices, scores)]
    results.sort(key=lambda x: x["score"], reverse=True)
    return results[0]["url"], results


app = FastAPI()


class RankRequest(BaseModel):
    # Body schema for POST /api/rank: a plain list of image URLs.
    urls: List[str]
@app.post("/api/rank")
async def rank_endpoint(req: RankRequest):
    """Rank the submitted image URLs; return the winner plus all scores.

    Raises:
        HTTPException(400): empty URL list, or no image could be loaded.
    """
    if not req.urls:
        raise HTTPException(status_code=400, detail="List of URLs cannot be empty")
    best_url, results = get_best_image(req.urls)
    if best_url is None:
        raise HTTPException(status_code=400, detail="Could not load any images")
    return {"best_image": best_url, "ranking": results}


def load_group_by_index(index):
    """Return the newline-joined URLs of 1-based group #index, or an error string."""
    idx = int(index) - 1
    if 0 <= idx < len(groups_data):
        return groups_data[idx]
    return "Invalid Index"


def load_random_group():
    """Pick a random demo group; returns (1-based index, newline-joined URLs)."""
    if not groups_data:
        return 1, "No data."
    rand_idx = random.randint(0, len(groups_data) - 1)
    return rand_idx + 1, groups_data[rand_idx]


def gradio_wrapper(text_input):
    """Gradio callback: rank the pasted URLs and fetch the winning image for display."""
    urls = text_input.split("\n")
    best_url, results = get_best_image(urls)
    if best_url is None:
        return None, "Error loading images"
    try:
        if best_url.startswith("http"):
            resp = requests.get(best_url, timeout=3)
            best_img_pil = Image.open(BytesIO(resp.content)).convert("RGB")
        else:
            best_img_pil = Image.open(best_url).convert("RGB")
    except Exception:  # fix: bare except also swallowed KeyboardInterrupt/SystemExit
        best_img_pil = None
    return best_img_pil, results


with gr.Blocks() as demo:
    # fix: was an f-string with no placeholders
    gr.Markdown("# 🏠 Real Estate Ranker (Student Model)")
    gr.Markdown("Using **MobileCLIP-B** (Distilled) with smart resizing.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            with gr.Row():
                index_input = gr.Number(value=1, label="Group #", minimum=1, precision=0)
                random_btn = gr.Button("🎲 Random", variant="secondary")
            load_btn = gr.Button("Load Group", size="sm")
            gr.Markdown("### 2. URLs")
            input_text = gr.Textbox(label="Image URLs", lines=6)
            rank_btn = gr.Button("🚀 Rank", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="🏆 Best Image", type="pil")
            output_json = gr.JSON(label="Scores")

    random_btn.click(fn=load_random_group, inputs=None, outputs=[index_input, input_text])
    load_btn.click(fn=load_group_by_index, inputs=index_input, outputs=input_text)
    rank_btn.click(fn=gradio_wrapper, inputs=input_text, outputs=[output_image, output_json])

# Serve the Gradio UI at "/" on the same FastAPI app that exposes /api/rank.
app = gr.mount_gradio_app(app, demo, path="/")