# clipick / app.py
# Last change by faststager: "Update app.py" (commit ada4422, verified)
import torch
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import os
import yaml
import requests
import json
import random
from PIL import Image, ImageOps
from io import BytesIO
from types import SimpleNamespace
from torchvision import transforms
from huggingface_hub import hf_hub_download
import mobileclip
from mobileclip.modules.common.mobileone import reparameterize_model
from model import MobileCLIPRanker
# Hugging Face repository and checkpoint file holding the fine-tuned ranker weights.
HF_USER_REPO = "Nightfury16/clipick"
HF_FILENAME = "best_model_2602.pth"

# Local files read at startup. Both are optional: a missing config falls back
# to built-in defaults and a missing JSON file leaves the group list empty.
CONFIG_PATH = "config.yml"
JSON_DATA_PATH = "combined_unique.json"

# Run on the GPU when PyTorch can see one, otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def load_config(path="config.yml"):
if not os.path.exists(path):
return SimpleNamespace(**{
"data": SimpleNamespace(img_size=224),
"model": SimpleNamespace(name="mobileclip2_l14")
})
with open(path, "r") as f:
cfg_dict = yaml.safe_load(f)
def recursive_namespace(d):
if isinstance(d, dict):
for k, v in d.items():
d[k] = recursive_namespace(v)
return SimpleNamespace(**d)
return d
return recursive_namespace(cfg_dict)
def _load_groups(path):
    """Read image-URL groups from the JSON file at *path*.

    Reads ``data["groups"]`` and each group's ``"images"`` list. Every group
    with at least one usable URL becomes one newline-joined string, ready to
    be pasted into the URL textbox. Loading is best-effort: a missing file
    yields ``[]``, and errors are printed rather than raised. Groups parsed
    before an error are kept.
    """
    groups = []
    if not os.path.exists(path):
        return groups
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for group in data.get("groups", []):
            # Keep only non-blank strings: one bad entry would otherwise make
            # "\n".join raise and abort loading every later group.
            urls = [
                u.strip()
                for u in group.get("images", [])
                if isinstance(u, str) and u.strip()
            ]
            if urls:
                groups.append("\n".join(urls))
        print(f"Loaded {len(groups)} groups from JSON.")
    except Exception as e:
        print(f"Error loading JSON data: {e}")
    return groups


# Each entry is one group's URLs joined by newlines (see _load_groups).
groups_data = _load_groups(JSON_DATA_PATH)
def _strip_dataparallel_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    Presumably the prefix comes from a DataParallel/DDP wrapper used during
    training. Only a *leading* occurrence is stripped; the same text elsewhere
    in a key (e.g. ``"...submodule.weight"``) is left untouched.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


print("--- Loading Ranker Server ---")
print(f"Device: {DEVICE}")
cfg = load_config(CONFIG_PATH)
model = MobileCLIPRanker(cfg)
try:
    print(f"Downloading Fine-Tuned weights ({HF_FILENAME}) from {HF_USER_REPO}...")
    local_weight_path = hf_hub_download(repo_id=HF_USER_REPO, filename=HF_FILENAME)
    # NOTE(security): torch.load unpickles the file. It comes from a fixed HF
    # repo, but consider passing weights_only=True explicitly if the checkpoint
    # only contains tensors and plain containers.
    checkpoint = torch.load(local_weight_path, map_location=DEVICE)
    # Accept either a training checkpoint that wraps "model_state_dict" or a bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        raw_state_dict = checkpoint["model_state_dict"]
    else:
        raw_state_dict = checkpoint
    state_dict = _strip_dataparallel_prefix(raw_state_dict)
    model.load_state_dict(state_dict, strict=True)
    print("✅ Weights loaded successfully!")
except Exception as e:
    print(f"❌ CRITICAL: Load failed. {e}")
    # Bare raise keeps the original traceback; startup aborts without weights.
    raise

print("⚑ Reparameterizing MobileCLIP-B for inference speed...")
# Only applied when the ranker exposes a `backbone` attribute.
if hasattr(model, "backbone"):
    model.backbone = reparameterize_model(model.backbone)
model.to(DEVICE)
model.eval()
# Per-channel RGB mean/std for input normalization. These look like the OpenAI
# CLIP statistics truncated to three decimals. Keep them identical to whatever
# the checkpoint was trained with rather than "correcting" them.
_NORM_MEAN = (0.481, 0.457, 0.408)
_NORM_STD = (0.268, 0.261, 0.275)

# PIL image -> float tensor in [0, 1] -> channel-wise standardized tensor.
norm_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def letterbox_image(img, size, fill=(128, 128, 128)):
    """Fit *img* inside a ``size`` x ``size`` square without distorting it.

    The image is shrunk (never enlarged) with bicubic resampling so that
    both sides are at most *size*. It is then padded with *fill* to exactly
    ``size`` x ``size``, centred. When the padding is odd, the extra pixel
    goes on the right/bottom.

    Args:
        img: Source PIL image. It is not modified.
        size: Target side length in pixels.
        fill: Padding colour. The default mid-grey tuple assumes an RGB image.

    Returns:
        A new ``size`` x ``size`` PIL image.
    """
    # thumbnail() resizes in place, so work on a copy to leave the caller's image intact.
    fitted = img.copy()
    fitted.thumbnail((size, size), Image.Resampling.BICUBIC)
    pad_w = size - fitted.width
    pad_h = size - fitted.height
    # (left, top, right, bottom): split evenly, remainder goes right/bottom.
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
    return ImageOps.expand(fitted, padding, fill=fill)
def _open_image(src, timeout=3):
    """Open *src* (a URL starting with "http" or a local path) as an RGB PIL image.

    Propagates whatever ``requests`` or PIL raise on failure.
    """
    if src.startswith("http"):
        resp = requests.get(src, timeout=timeout)
        # Fail fast on 4xx/5xx instead of handing an error page to PIL.
        resp.raise_for_status()
        with Image.open(BytesIO(resp.content)) as raw:
            return raw.convert("RGB")
    # NOTE(security): *src* comes straight from API/UI users. This branch lets
    # them make the server open arbitrary local image paths, and the branch
    # above lets them make it fetch arbitrary URLs. Consider an allow-list.
    with Image.open(src) as raw:
        return raw.convert("RGB")


def get_best_image(url_list, timeout=3):
    """Score every loadable image in *url_list* and return the best one.

    Entries that are not non-blank strings are ignored; the rest are stripped.
    Each image is opened, letterboxed to ``cfg.data.img_size`` and normalized.
    All loaded images are then scored in one forward pass of the global
    ``model``. Images that fail to load are printed and skipped.

    Args:
        url_list: Iterable of http(s) URLs and/or local file paths.
        timeout: Per-request HTTP timeout in seconds.

    Returns:
        ``(best_url, results)``. *results* is a list of
        ``{"url": str, "score": float}`` sorted by descending score, and
        *best_url* is the URL of its first entry. Returns ``(None, [])`` when
        no image could be loaded.
    """
    clean_urls = [u.strip() for u in url_list if isinstance(u, str) and u.strip()]
    print(f"Processing {len(clean_urls)} images...")

    valid_tensors = []
    valid_indices = []
    for i, src in enumerate(clean_urls):
        try:
            img = _open_image(src, timeout=timeout)
            img_padded = letterbox_image(img, cfg.data.img_size)
            valid_tensors.append(norm_transform(img_padded))
            valid_indices.append(i)
        except Exception as e:
            # Best-effort: one bad URL must not abort the whole ranking.
            print(f"Error loading {src}: {e}")

    if not valid_tensors:
        return None, []

    # Stack to (N, C, H, W), then add a leading group dim -> (1, N, C, H, W).
    batch = torch.stack(valid_tensors).unsqueeze(0).to(DEVICE)
    # Number of real candidates in the single group (presumably used for masking).
    valid_len = torch.tensor([len(valid_tensors)], device=DEVICE)
    with torch.no_grad():
        scores = model(batch, valid_lens=valid_len).view(-1).cpu().numpy()

    results = [
        {"url": clean_urls[idx], "score": float(score)}
        for idx, score in zip(valid_indices, scores)
    ]
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[0]["url"], results
app = FastAPI()


class RankRequest(BaseModel):
    # JSON body of POST /api/rank, e.g. {"urls": ["https://...", "..."]}.
    urls: List[str]


@app.post("/api/rank")
def rank_endpoint(req: RankRequest):
    """Rank the images in ``req.urls`` and return the best one plus all scores.

    This is declared with plain ``def`` on purpose. Ranking performs blocking
    HTTP downloads and model inference, and FastAPI runs sync path operations
    in a worker threadpool. As ``async def`` the work ran directly on the event
    loop, stalling every other request, including the mounted Gradio UI.

    Raises:
        HTTPException: 400 if ``urls`` is empty or none of the images load.
    """
    if not req.urls:
        raise HTTPException(status_code=400, detail="List of URLs cannot be empty")
    best_url, results = get_best_image(req.urls)
    if best_url is None:
        raise HTTPException(status_code=400, detail="Could not load any images")
    return {"best_image": best_url, "ranking": results}
def load_group_by_index(index):
    """Return the newline-joined URLs of group number *index* (1-based).

    Returns the string ``"Invalid Index"`` when *index* is missing, not
    convertible to an int, or outside ``1..len(groups_data)``.
    """
    try:
        idx = int(index) - 1
    except (TypeError, ValueError, OverflowError):
        # e.g. None from an empty number field, NaN, or infinity.
        return "Invalid Index"
    if 0 <= idx < len(groups_data):
        return groups_data[idx]
    return "Invalid Index"
def load_random_group():
    """Pick a uniformly random group for the UI.

    Returns:
        ``(group_number, urls_text)`` with a 1-based *group_number*, or
        ``(1, "No data.")`` when no groups were loaded.
    """
    total = len(groups_data)
    if total == 0:
        return 1, "No data."
    pick = random.randrange(total)
    return pick + 1, groups_data[pick]
def gradio_wrapper(text_input):
    """Gradio handler: rank the newline-separated URLs in *text_input*.

    Returns:
        ``(best_image, results)`` for the image and JSON outputs.

        - *best_image* is the top-ranked image reloaded as an RGB PIL image,
          or ``None`` if that reload fails.
        - *results* is the ranking list from ``get_best_image``.
        - When nothing could be ranked, returns
          ``(None, "Error loading images")``.
    """
    # Treat a missing value (None) exactly like an empty textbox.
    urls = (text_input or "").split("\n")
    best_url, results = get_best_image(urls)
    if best_url is None:
        return None, "Error loading images"
    try:
        if best_url.startswith("http"):
            resp = requests.get(best_url, timeout=3)
            resp.raise_for_status()
            with Image.open(BytesIO(resp.content)) as raw:
                best_img_pil = raw.convert("RGB")
        else:
            with Image.open(best_url) as raw:
                best_img_pil = raw.convert("RGB")
    except Exception as e:
        # Best-effort preview: the ranking is still returned even if the
        # winning image cannot be fetched a second time.
        print(f"Error loading best image {best_url}: {e}")
        best_img_pil = None
    return best_img_pil, results
with gr.Blocks() as demo:
    gr.Markdown("# 🏠 Real Estate Ranker (Student Model)")
    gr.Markdown("Using **MobileCLIP-B** (Distilled) with smart resizing.")
    with gr.Row():
        # Left column: choose a group (by number or at random) and edit its URLs.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            with gr.Row():
                index_input = gr.Number(value=1, label="Group #", minimum=1, precision=0)
                random_btn = gr.Button("🎲 Random", variant="secondary")
            load_btn = gr.Button("Load Group", size="sm")
            gr.Markdown("### 2. URLs")
            input_text = gr.Textbox(label="Image URLs", lines=6)
            rank_btn = gr.Button("🚀 Rank", variant="primary")
        # Right column: the winning image and the full score list.
        with gr.Column(scale=1):
            output_image = gr.Image(label="🏆 Best Image", type="pil")
            output_json = gr.JSON(label="Scores")
    # Event wiring (registered inside the Blocks context).
    random_btn.click(fn=load_random_group, inputs=None, outputs=[index_input, input_text])
    load_btn.click(fn=load_group_by_index, inputs=index_input, outputs=input_text)
    rank_btn.click(fn=gradio_wrapper, inputs=input_text, outputs=[output_image, output_json])

# Serve the UI at "/" on the same FastAPI app; /api/rank was registered first.
app = gr.mount_gradio_app(app, demo, path="/")