# NOTE: this file was recovered from a Hugging Face Spaces page dump
# (page header read "Spaces: Runtime error"); formatting restored below.
# Standard library
import json
import os
import random
from io import BytesIO
from types import SimpleNamespace
from typing import List

# Third-party
import gradio as gr
import requests
import torch
import yaml
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
from pydantic import BaseModel
from torchvision import transforms

# Project / model code
import mobileclip
from mobileclip.modules.common.mobileone import reparameterize_model
from model import MobileCLIPRanker
# Hugging Face Hub location of the fine-tuned ranker weights.
HF_USER_REPO = "Nightfury16/clipick"
HF_FILENAME = "best_model_2602.pth"

# Local paths: YAML model/config file and the demo image-group dataset.
CONFIG_PATH = "config.yml"
JSON_DATA_PATH = "combined_unique.json"

# Prefer GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_config(path="config.yml"):
    """Load a YAML config file into nested SimpleNamespace objects.

    Falls back to built-in defaults when *path* does not exist, so the
    app can still boot without a config file on disk.
    """
    if not os.path.exists(path):
        return SimpleNamespace(
            data=SimpleNamespace(img_size=224),
            model=SimpleNamespace(name="mobileclip2_l14"),
        )

    with open(path, "r") as f:
        raw = yaml.safe_load(f)

    def to_namespace(node):
        # Recursively convert nested dicts to attribute-style access.
        if not isinstance(node, dict):
            return node
        return SimpleNamespace(**{k: to_namespace(v) for k, v in node.items()})

    return to_namespace(raw)
| groups_data = [] | |
| try: | |
| if os.path.exists(JSON_DATA_PATH): | |
| with open(JSON_DATA_PATH, "r") as f: | |
| data = json.load(f) | |
| for group in data.get("groups", []): | |
| urls = group.get("images", []) | |
| if urls: | |
| groups_data.append("\n".join(urls)) | |
| print(f"Loaded {len(groups_data)} groups from JSON.") | |
| except Exception as e: | |
| print(f"Error loading JSON data: {e}") | |
# ---- Server bootstrap: build the ranker and load fine-tuned weights ----
print("--- Loading Ranker Server ---")
print(f"Device: {DEVICE}")

cfg = load_config(CONFIG_PATH)
model = MobileCLIPRanker(cfg)

try:
    print(f"Downloading Fine-Tuned weights ({HF_FILENAME}) from {HF_USER_REPO}...")
    local_weight_path = hf_hub_download(repo_id=HF_USER_REPO, filename=HF_FILENAME)
    checkpoint = torch.load(local_weight_path, map_location=DEVICE)

    # The checkpoint may be a full training dict or a bare state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        raw_state_dict = checkpoint["model_state_dict"]
    else:
        raw_state_dict = checkpoint

    # Strip DataParallel's "module." prefix so keys match the bare model.
    state_dict = {name.replace("module.", ""): tensor
                  for name, tensor in raw_state_dict.items()}
    model.load_state_dict(state_dict, strict=True)
    print("β Weights loaded successfully!")
except Exception as e:
    print(f"β CRITICAL: Load failed. {e}")
    raise e

print("β‘ Reparameterizing MobileCLIP-B for inference speed...")
if hasattr(model, 'backbone'):
    # Fold train-time branches into plain convs for faster inference.
    model.backbone = reparameterize_model(model.backbone)

model.to(DEVICE)
model.eval()

# NOTE(review): these mean/std values presumably match the backbone's
# training normalization — confirm against the MobileCLIP preprocessing.
norm_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.481, 0.457, 0.408), std=(0.268, 0.261, 0.275))
])
def letterbox_image(img, size):
    """Pad image to a square of side *size*, preserving aspect ratio.

    The image is scaled down (never up — `thumbnail` only shrinks) to fit
    within (size, size), then centered on a mid-grey background.

    Fix: operate on a copy — `Image.thumbnail` resizes in place, and the
    original code mutated the caller's image as a side effect.
    """
    scaled = img.copy()
    scaled.thumbnail((size, size), Image.Resampling.BICUBIC)
    delta_w = size - scaled.size[0]
    delta_h = size - scaled.size[1]
    padding = (delta_w // 2, delta_h // 2,
               delta_w - delta_w // 2, delta_h - delta_h // 2)
    return ImageOps.expand(scaled, padding, fill=(128, 128, 128))
def get_best_image(url_list):
    """Score every loadable image in *url_list* with the ranker model.

    Args:
        url_list: iterable of image sources — http(s) URLs or local paths.

    Returns:
        (best_url, results): results is a list of {"url", "score"} dicts
        sorted by descending score; (None, []) when no image could be loaded.
    """
    # Drop blanks / non-strings up front so indices align with clean_urls.
    clean_urls = [u.strip() for u in url_list if isinstance(u, str) and u.strip()]
    print(f"Processing {len(clean_urls)} images...")

    valid_tensors = []
    valid_indices = []
    for i, src in enumerate(clean_urls):
        try:
            if src.startswith("http"):
                resp = requests.get(src, timeout=3)
                # Fail fast on 404/500 instead of feeding an HTML error
                # page's bytes to PIL.
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content)).convert("RGB")
            else:
                img = Image.open(src).convert("RGB")
            img_padded = letterbox_image(img, cfg.data.img_size)
            valid_tensors.append(norm_transform(img_padded))
            valid_indices.append(i)
        except Exception as e:
            # Best-effort: skip unloadable images and rank the rest.
            print(f"Error loading {src}: {e}")

    if not valid_tensors:
        return None, []

    # Model consumes a (1, N, C, H, W) batch plus the count of valid slots.
    batch = torch.stack(valid_tensors).unsqueeze(0).to(DEVICE)
    valid_len = torch.tensor([len(valid_tensors)]).to(DEVICE)
    with torch.no_grad():
        scores = model(batch, valid_lens=valid_len).view(-1).cpu().numpy()

    results = [{"url": clean_urls[idx], "score": float(score)}
               for idx, score in zip(valid_indices, scores)]
    results.sort(key=lambda x: x["score"], reverse=True)
    return results[0]["url"], results
app = FastAPI()


class RankRequest(BaseModel):
    """Request payload for the ranking endpoint: candidate image URLs."""

    urls: List[str]
@app.post("/rank")  # Fix: the endpoint was never registered with FastAPI.
async def rank_endpoint(req: RankRequest):
    """Rank the submitted image URLs.

    Returns {"best_image": url, "ranking": [{"url", "score"}, ...]}.
    Raises 400 when the list is empty or no image could be loaded.
    """
    if not req.urls:
        raise HTTPException(status_code=400, detail="List of URLs cannot be empty")
    best_url, results = get_best_image(req.urls)
    if best_url is None:
        raise HTTPException(status_code=400, detail="Could not load any images")
    return {"best_image": best_url, "ranking": results}
def load_group_by_index(index, groups=None):
    """Return the newline-joined URL string for 1-based group *index*.

    Args:
        index: 1-based group number (Gradio may hand over a float or None).
        groups: optional list to look up instead of the module-level
            ``groups_data`` (default) — also makes the function testable.

    Returns:
        The group's URL string, or "Invalid Index" when out of range or
        when *index* is not a number.
    """
    source = groups_data if groups is None else groups
    try:
        idx = int(index) - 1
    except (TypeError, ValueError):
        # Guard: Number widgets can yield None / non-numeric input.
        return "Invalid Index"
    if 0 <= idx < len(source):
        return source[idx]
    return "Invalid Index"
def load_random_group(groups=None):
    """Pick a random group for the demo UI.

    Args:
        groups: optional list to pick from instead of the module-level
            ``groups_data`` (default) — also makes the function testable.

    Returns:
        (1-based index, group URL string), or (1, "No data.") when empty.
    """
    source = groups_data if groups is None else groups
    if not source:
        return 1, "No data."
    rand_idx = random.randint(0, len(source) - 1)
    return rand_idx + 1, source[rand_idx]
def gradio_wrapper(text_input):
    """Gradio callback: rank newline-separated URLs.

    Returns (best image as PIL or None, ranking results / error string).
    """
    urls = text_input.split("\n")
    best_url, results = get_best_image(urls)
    if best_url is None:
        return None, "Error loading images"

    best_img_pil = None
    try:
        if best_url.startswith("http"):
            resp = requests.get(best_url, timeout=3)
            best_img_pil = Image.open(BytesIO(resp.content)).convert("RGB")
        else:
            best_img_pil = Image.open(best_url).convert("RGB")
    except Exception as e:
        # Fix: was a bare `except:` that also swallowed KeyboardInterrupt /
        # SystemExit; narrow to Exception and log the failure.
        print(f"Error re-loading best image {best_url}: {e}")
    return best_img_pil, results
# ---- Gradio demo UI, mounted at "/" on the FastAPI app ----
with gr.Blocks() as demo:
    gr.Markdown("# π Real Estate Ranker (Student Model)")
    gr.Markdown("Using **MobileCLIP-B** (Distilled) with smart resizing.")

    with gr.Row():
        # Left column: data selection and URL entry.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            with gr.Row():
                index_input = gr.Number(value=1, label="Group #", minimum=1, precision=0)
                random_btn = gr.Button("π² Random", variant="secondary")
            load_btn = gr.Button("Load Group", size="sm")

            gr.Markdown("### 2. URLs")
            input_text = gr.Textbox(label="Image URLs", lines=6)
            rank_btn = gr.Button("π Rank", variant="primary")

        # Right column: ranking results.
        with gr.Column(scale=1):
            output_image = gr.Image(label="π Best Image", type="pil")
            output_json = gr.JSON(label="Scores")

    # Event wiring.
    random_btn.click(fn=load_random_group, inputs=None,
                     outputs=[index_input, input_text])
    load_btn.click(fn=load_group_by_index, inputs=index_input,
                   outputs=input_text)
    rank_btn.click(fn=gradio_wrapper, inputs=input_text,
                   outputs=[output_image, output_json])

# Serve the Gradio UI from the FastAPI app's root path.
app = gr.mount_gradio_app(app, demo, path="/")