Nightfury16 Claude Opus 4.6 committed on
Commit
3aa023a
·
1 Parent(s): ada4422

Fix mobileclip2_l14 checkpoint loading and add pyvips acceleration

Browse files

- Switch model from mobileclip_b to mobileclip2_l14 matching checkpoint 2602
- Fix head architecture: nn.Linear -> 2-layer MLP (Linear/GELU/Dropout/Linear)
matching training code's RankingHead with head.net.{0,3} key layout
- Use GELU activation (not ReLU) to match training exactly
- Infer head_hidden_dim from checkpoint at load time
- Remove reparameterize_model (MobileOne-specific, not applicable to ViT-L/14)
- Replace PIL with pyvips (shrink-on-load thumbnail_buffer for fast JPEG decode)
- Replace sequential requests with urllib3 PoolManager + ThreadPoolExecutor(16)
- Add torch.inference_mode, fp16 autocast on CUDA, torch.compile
- Add packages.txt (libvips-dev) for HF Spaces

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (5) hide show
  1. app.py +180 -135
  2. config.yml +1 -1
  3. model.py +51 -46
  4. packages.txt +1 -0
  5. requirements.txt +3 -4
app.py CHANGED
@@ -1,53 +1,45 @@
1
  import torch
 
 
2
  import gradio as gr
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from typing import List
 
 
 
 
6
  import os
7
- import yaml
8
- import requests
9
  import json
10
  import random
11
- from PIL import Image, ImageOps
12
- from io import BytesIO
13
- from types import SimpleNamespace
14
- from torchvision import transforms
15
- from huggingface_hub import hf_hub_download
16
 
17
- import mobileclip
18
- from mobileclip.modules.common.mobileone import reparameterize_model
19
  from model import MobileCLIPRanker
20
 
 
21
  HF_USER_REPO = "Nightfury16/clipick"
22
- HF_FILENAME = "best_model_2602.pth"
23
- CONFIG_PATH = "config.yml"
24
  JSON_DATA_PATH = "combined_unique.json"
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
26
 
27
- def load_config(path="config.yml"):
28
- if not os.path.exists(path):
29
- return SimpleNamespace(**{
30
- "data": SimpleNamespace(img_size=224),
31
- "model": SimpleNamespace(name="mobileclip2_l14")
32
- })
33
-
34
- with open(path, "r") as f:
35
- cfg_dict = yaml.safe_load(f)
36
-
37
- def recursive_namespace(d):
38
- if isinstance(d, dict):
39
- for k, v in d.items():
40
- d[k] = recursive_namespace(v)
41
- return SimpleNamespace(**d)
42
- return d
43
- return recursive_namespace(cfg_dict)
44
 
 
 
 
 
 
 
 
 
 
45
  groups_data = []
46
  try:
47
  if os.path.exists(JSON_DATA_PATH):
48
  with open(JSON_DATA_PATH, "r") as f:
49
- data = json.load(f)
50
- for group in data.get("groups", []):
51
  urls = group.get("images", [])
52
  if urls:
53
  groups_data.append("\n".join(urls))
@@ -55,151 +47,204 @@ try:
55
  except Exception as e:
56
  print(f"Error loading JSON data: {e}")
57
 
 
58
  print("--- Loading Ranker Server ---")
59
  print(f"Device: {DEVICE}")
60
 
61
- cfg = load_config(CONFIG_PATH)
62
- model = MobileCLIPRanker(cfg)
 
 
63
 
64
- try:
65
- print(f"Downloading Fine-Tuned weights ({HF_FILENAME}) from {HF_USER_REPO}...")
66
- local_weight_path = hf_hub_download(repo_id=HF_USER_REPO, filename=HF_FILENAME)
67
-
68
- checkpoint = torch.load(local_weight_path, map_location=DEVICE)
69
-
70
- if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
71
- raw_state_dict = checkpoint["model_state_dict"]
72
- else:
73
- raw_state_dict = checkpoint
74
-
75
- state_dict = {k.replace("module.", ""): v for k, v in raw_state_dict.items()}
76
-
77
- model.load_state_dict(state_dict, strict=True)
78
- print("✅ Weights loaded successfully!")
79
-
80
- except Exception as e:
81
- print(f"❌ CRITICAL: Load failed. {e}")
82
- raise e
83
 
84
- print("⚡ Reparameterizing MobileCLIP-B for inference speed...")
85
- if hasattr(model, 'backbone'):
86
- model.backbone = reparameterize_model(model.backbone)
 
87
 
88
- model.to(DEVICE)
89
- model.eval()
90
 
91
- norm_transform = transforms.Compose([
92
- transforms.ToTensor(),
93
- transforms.Normalize(mean=(0.481, 0.457, 0.408), std=(0.268, 0.261, 0.275))
94
- ])
 
 
 
95
 
96
- def letterbox_image(img, size):
97
- '''Pad image to square to preserve aspect ratio (No distortion)'''
98
- img.thumbnail((size, size), Image.Resampling.BICUBIC)
99
- delta_w = size - img.size[0]
100
- delta_h = size - img.size[1]
101
- padding = (delta_w//2, delta_h//2, delta_w-(delta_w//2), delta_h-(delta_h//2))
102
- return ImageOps.expand(img, padding, fill=(128, 128, 128))
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def get_best_image(url_list):
105
- valid_tensors = []
106
- valid_indices = []
107
- clean_urls = []
108
-
109
- for u in url_list:
110
- if isinstance(u, str) and u.strip():
111
- clean_urls.append(u.strip())
112
-
113
- print(f"Processing {len(clean_urls)} images...")
114
-
115
- for i, src in enumerate(clean_urls):
116
- try:
117
- if src.startswith("http"):
118
- resp = requests.get(src, timeout=3)
119
- img = Image.open(BytesIO(resp.content)).convert("RGB")
120
- else:
121
- img = Image.open(src).convert("RGB")
122
-
123
- img_padded = letterbox_image(img, cfg.data.img_size)
124
- tensor = norm_transform(img_padded)
125
-
126
- valid_tensors.append(tensor)
127
- valid_indices.append(i)
128
- except Exception as e:
129
- print(f"Error loading {src}: {e}")
130
-
131
- if not valid_tensors:
132
  return None, []
133
 
134
- batch = torch.stack(valid_tensors).unsqueeze(0).to(DEVICE)
135
- valid_len = torch.tensor([len(valid_tensors)]).to(DEVICE)
136
-
137
- with torch.no_grad():
138
- scores = model(batch, valid_lens=valid_len).view(-1).cpu().numpy()
139
-
140
- results = []
141
- for idx, score in zip(valid_indices, scores):
142
- results.append({"url": clean_urls[idx], "score": float(score)})
 
 
 
 
 
143
 
144
- results.sort(key=lambda x: x["score"], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  return results[0]["url"], results
146
 
 
 
147
  app = FastAPI()
148
 
 
149
  class RankRequest(BaseModel):
150
  urls: List[str]
151
 
 
152
  @app.post("/api/rank")
153
  async def rank_endpoint(req: RankRequest):
154
  if not req.urls:
155
- raise HTTPException(status_code=400, detail="List of URLs cannot be empty")
156
- best_url, results = get_best_image(req.urls)
157
- if best_url is None:
158
- raise HTTPException(status_code=400, detail="Could not load any images")
159
- return {"best_image": best_url, "ranking": results}
 
160
 
 
161
  def load_group_by_index(index):
162
  idx = int(index) - 1
163
- if 0 <= idx < len(groups_data): return groups_data[idx]
164
- return "Invalid Index"
165
 
166
  def load_random_group():
167
- if not groups_data: return 1, "No data."
168
- rand_idx = random.randint(0, len(groups_data) - 1)
169
- return rand_idx + 1, groups_data[rand_idx]
 
 
170
 
171
  def gradio_wrapper(text_input):
172
- urls = text_input.split("\n")
173
- best_url, results = get_best_image(urls)
174
- if best_url is None: return None, "Error loading images"
175
- try:
176
- if best_url.startswith("http"):
177
- resp = requests.get(best_url, timeout=3)
178
- best_img_pil = Image.open(BytesIO(resp.content)).convert("RGB")
179
- else:
180
- best_img_pil = Image.open(best_url).convert("RGB")
181
- except: best_img_pil = None
182
- return best_img_pil, results
183
 
184
  with gr.Blocks() as demo:
185
- gr.Markdown(f"# 🏠 Real Estate Ranker (Student Model)")
186
- gr.Markdown("Using **MobileCLIP-B** (Distilled) with smart resizing.")
187
  with gr.Row():
188
  with gr.Column(scale=1):
189
  gr.Markdown("### 1. Select Data")
190
  with gr.Row():
191
- index_input = gr.Number(value=1, label="Group #", minimum=1, precision=0)
192
- random_btn = gr.Button("🎲 Random", variant="secondary")
 
 
193
  load_btn = gr.Button("Load Group", size="sm")
194
  gr.Markdown("### 2. URLs")
195
  input_text = gr.Textbox(label="Image URLs", lines=6)
196
- rank_btn = gr.Button("🚀 Rank", variant="primary")
197
  with gr.Column(scale=1):
198
- output_image = gr.Image(label="🏆 Best Image", type="pil")
199
  output_json = gr.JSON(label="Scores")
200
-
201
- random_btn.click(fn=load_random_group, inputs=None, outputs=[index_input, input_text])
202
  load_btn.click(fn=load_group_by_index, inputs=index_input, outputs=input_text)
203
- rank_btn.click(fn=gradio_wrapper, inputs=input_text, outputs=[output_image, output_json])
 
 
204
 
205
- app = gr.mount_gradio_app(app, demo, path="/")
 
1
  import torch
2
+ import numpy as np
3
+ import pyvips
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from typing import List
8
+ from contextlib import nullcontext
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from huggingface_hub import hf_hub_download
11
+ import urllib3
12
  import os
 
 
13
  import json
14
  import random
 
 
 
 
 
15
 
 
 
16
  from model import MobileCLIPRanker
17
 
18
+ # ── Config ──────────────────────────────────────────────────────────────
19
  HF_USER_REPO = "Nightfury16/clipick"
20
+ HF_FILENAME = "best_model_2602.pth"
 
21
  JSON_DATA_PATH = "combined_unique.json"
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
+ IMG_SIZE = 224
24
 
25
+ # Normalisation constants (pre-shaped for numpy broadcast)
26
+ MEAN = np.float32([0.481, 0.457, 0.408]).reshape(1, 1, 3)
27
+ INV_STD = (1.0 / np.float32([0.268, 0.261, 0.275])).reshape(1, 1, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
# ── Connection & thread pools ───────────────────────────────────────────
# Shared urllib3 pool: keep-alive connections are reused across requests;
# 1 retry and tight 2s connect / 3s read timeouts so a single slow image
# host cannot stall a whole ranking request for long.
http_pool = urllib3.PoolManager(
    maxsize=32,
    retries=urllib3.Retry(total=1, backoff_factor=0),
    timeout=urllib3.Timeout(connect=2.0, read=3.0),
)
# Worker pool for parallel image fetch + preprocess (I/O-bound -> threads).
fetch_pool = ThreadPoolExecutor(max_workers=16)
36
+
37
# ── Load group data ─────────────────────────────────────────────────────
# Each entry is one newline-joined string of image URLs (a demo "group").
groups_data = []
try:
    if os.path.exists(JSON_DATA_PATH):
        with open(JSON_DATA_PATH, "r") as f:
            payload = json.load(f)
        groups_data = [
            "\n".join(grp.get("images", []))
            for grp in payload.get("groups", [])
            if grp.get("images", [])
        ]
except Exception as e:
    # Best-effort: the demo still works without preloaded groups.
    print(f"Error loading JSON data: {e}")
49
 
50
# ── Load model ──────────────────────────────────────────────────────────
print("--- Loading Ranker Server ---")
print(f"Device: {DEVICE}")

# 1. Download fine-tuned checkpoint first to infer head dimensions
print(f"Downloading fine-tuned weights ({HF_FILENAME})...")
local_weight_path = hf_hub_download(repo_id=HF_USER_REPO, filename=HF_FILENAME)
try:
    # Prefer weights_only=True: avoids arbitrary-code pickle execution when
    # deserializing a downloaded checkpoint (and matches the torch>=2.6
    # default for torch.load).
    checkpoint = torch.load(
        local_weight_path, map_location=DEVICE, weights_only=True
    )
except Exception:
    # Fallback for older torch or checkpoints that pickle non-tensor
    # objects (optimizer/config state) alongside the weights.
    checkpoint = torch.load(local_weight_path, map_location=DEVICE)

# Accept either a raw state dict or a training checkpoint wrapper.
raw_sd = (
    checkpoint.get("model_state_dict", checkpoint)
    if isinstance(checkpoint, dict)
    else checkpoint
)
# Training used DataParallel/DDP ("module." prefix); strip for plain model.
state_dict = {k.replace("module.", ""): v for k, v in raw_sd.items()}

# Infer hidden dim from checkpoint so architecture matches exactly:
# head.net.0 is RankingHead's first Linear; its row count is the width.
head_hidden = state_dict["head.net.0.weight"].shape[0]
print(f"Head hidden dim inferred from checkpoint: {head_hidden}")

# 2. Build model with matching architecture, load weights (strict: any
# key mismatch means the architecture drifted from training — fail loudly).
model = MobileCLIPRanker(backbone_dim=768, head_hidden_dim=head_hidden)
model.load_state_dict(state_dict, strict=True)
print("Weights loaded successfully.")

model.to(DEVICE).eval()

# 3. Compile for faster inference on CUDA (best-effort; failure is benign)
if DEVICE == "cuda" and hasattr(torch, "compile"):
    try:
        model = torch.compile(model, mode="reduce-overhead")
        print("Model compiled with torch.compile (reduce-overhead)")
    except Exception:
        pass
84
 
 
 
 
 
 
 
 
85
 
86
# ── Image processing (pyvips) ──────────────────────────────────────────
def _fetch_and_preprocess(url: str):
    """Fetch one image, letterbox-resize, normalise -> CHW float32 numpy.

    Returns None on any fetch/decode failure so callers can skip the image.
    """
    try:
        if url.startswith("http"):
            resp = http_pool.request("GET", url, preload_content=True)
            if resp.status != 200:
                return None
            # thumbnail_buffer uses shrink-on-load (fast JPEG decode)
            img = pyvips.Image.thumbnail_buffer(
                resp.data, IMG_SIZE, height=IMG_SIZE
            )
        else:
            img = pyvips.Image.thumbnail(url, IMG_SIZE, height=IMG_SIZE)

        # Normalise to exactly 3 uchar bands. This also covers 2-band
        # grey+alpha and CMYK inputs, which the bands==4/bands==1 checks
        # alone would miss (they'd fail the fixed (H, W, 3) reshape below).
        if img.hasalpha():
            img = img.flatten(background=[128, 128, 128])
        if img.bands != 3:
            img = img.colourspace("srgb")
        if img.format != "uchar":
            # e.g. 16-bit PNGs decode as ushort; the uint8 view below
            # assumes one byte per sample.
            img = img.cast("uchar")

        # Letterbox pad to exact IMG_SIZE x IMG_SIZE (grey background,
        # matching the training-time letterbox fill).
        if img.width != IMG_SIZE or img.height != IMG_SIZE:
            img = img.gravity(
                "centre", IMG_SIZE, IMG_SIZE,
                extend="background", background=[128, 128, 128],
            )

        # -> float32 CHW normalised numpy (MEAN/INV_STD broadcast over HWC)
        arr = np.ndarray(
            buffer=img.write_to_memory(),
            dtype=np.uint8,
            shape=(IMG_SIZE, IMG_SIZE, 3),
        )
        arr = (arr.astype(np.float32) * (1.0 / 255.0) - MEAN) * INV_STD
        return arr.transpose(2, 0, 1)  # HWC -> CHW
    except Exception:
        return None
124
+
125
+
126
def _fetch_display(url: str):
    """Fetch image for Gradio display -> numpy uint8 HWC, or None on failure."""
    try:
        if url.startswith("http"):
            resp = http_pool.request("GET", url, preload_content=True)
            # Check status (consistent with _fetch_and_preprocess) instead
            # of handing an error page's bytes to the image decoder.
            if resp.status != 200:
                return None
            img = pyvips.Image.new_from_buffer(resp.data, "")
        else:
            img = pyvips.Image.new_from_file(url, access="sequential")
        # Normalise to 3 uchar bands (grey, grey+alpha, RGBA, CMYK, 16-bit).
        if img.hasalpha():
            img = img.flatten(background=[255, 255, 255])
        if img.bands != 3:
            img = img.colourspace("srgb")
        if img.format != "uchar":
            img = img.cast("uchar")
        return np.ndarray(
            buffer=img.write_to_memory(),
            dtype=np.uint8,
            shape=(img.height, img.width, 3),
        )
    except Exception:
        return None
145
+
146
+
147
# ── Core ranking logic ──────────────────────────────────────────────────
def get_best_image(url_list):
    """Rank image URLs/paths; return (best_url, results sorted by score desc).

    Returns (None, []) when no URL could be fetched and decoded.
    """
    clean = [u.strip() for u in url_list if isinstance(u, str) and u.strip()]
    if not clean:
        return None, []

    # Fan out fetch+preprocess across the thread pool, remembering each
    # URL's original position so scores can be mapped back.
    jobs = [(i, fetch_pool.submit(_fetch_and_preprocess, u))
            for i, u in enumerate(clean)]
    indices, arrays = [], []
    for i, job in jobs:
        chw = job.result()
        if chw is not None:
            indices.append(i)
            arrays.append(chw)

    if not arrays:
        return None, []

    # Model expects a (batch=1, group, C, H, W) tensor plus valid lengths.
    batch = torch.from_numpy(np.stack(arrays)).unsqueeze(0).to(DEVICE)
    vlens = torch.tensor([len(arrays)], device=DEVICE)

    if DEVICE == "cuda":
        amp_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp_ctx = nullcontext()
    with torch.inference_mode(), amp_ctx:
        scores = model(batch, valid_lens=vlens).view(-1).cpu().numpy()

    results = [
        {"url": clean[i], "score": float(s)} for i, s in zip(indices, scores)
    ]
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[0]["url"], results
185
 
186
+
187
# ── FastAPI ─────────────────────────────────────────────────────────────
app = FastAPI()


class RankRequest(BaseModel):
    # Raw list of image URLs (or local paths) to rank.
    urls: List[str]


@app.post("/api/rank")
def rank_endpoint(req: RankRequest):
    """Rank submitted image URLs; return the best one plus all scores.

    Declared as a plain `def` (not `async def`) on purpose: get_best_image
    does blocking network I/O and model inference, and FastAPI runs sync
    endpoints in a worker threadpool, so the event loop is not stalled.
    """
    if not req.urls:
        raise HTTPException(400, "URL list is empty")
    best, results = get_best_image(req.urls)
    if best is None:
        raise HTTPException(400, "No images could be loaded")
    return {"best_image": best, "ranking": results}
203
+
204
 
205
+ # ── Gradio UI ───────────────────────────────────────────────────────────
206
def load_group_by_index(index):
    """Return the newline-joined URL group at 1-based *index*."""
    pos = int(index) - 1
    if 0 <= pos < len(groups_data):
        return groups_data[pos]
    return "Invalid index"
209
+
210
 
211
def load_random_group():
    """Pick a random group; return (1-based index, group text)."""
    if not groups_data:
        return 1, "No data."
    pos = random.randint(0, len(groups_data) - 1)
    group_text = groups_data[pos]
    return pos + 1, group_text
216
+
217
 
218
def gradio_wrapper(text_input):
    """Gradio callback: rank pasted URLs -> (best image array, results)."""
    urls = text_input.split("\n")
    best_url, ranking = get_best_image(urls)
    if best_url is not None:
        return _fetch_display(best_url), ranking
    return None, "Error loading images"
223
+
 
 
 
 
 
 
224
 
225
# ── Gradio UI layout: data picker on the left, results on the right ────
with gr.Blocks() as demo:
    gr.Markdown("# Real Estate Image Ranker")
    gr.Markdown("**MobileCLIP2-L14** fine-tuned ranker with pyvips acceleration.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            with gr.Row():
                # 1-based group selector into groups_data
                index_input = gr.Number(
                    value=1, label="Group #", minimum=1, precision=0
                )
                random_btn = gr.Button("Random", variant="secondary")
            load_btn = gr.Button("Load Group", size="sm")
            gr.Markdown("### 2. URLs")
            input_text = gr.Textbox(label="Image URLs", lines=6)
            rank_btn = gr.Button("Rank", variant="primary")
        with gr.Column(scale=1):
            # type="numpy" because _fetch_display returns uint8 HWC arrays
            output_image = gr.Image(label="Best Image", type="numpy")
            output_json = gr.JSON(label="Scores")

    # Wire buttons to the callbacks defined above
    random_btn.click(fn=load_random_group, outputs=[index_input, input_text])
    load_btn.click(fn=load_group_by_index, inputs=index_input, outputs=input_text)
    rank_btn.click(
        fn=gradio_wrapper, inputs=input_text, outputs=[output_image, output_json]
    )

# Serve the Gradio demo at "/" on the same ASGI app as the REST API
app = gr.mount_gradio_app(app, demo, path="/")
config.yml CHANGED
@@ -1,4 +1,4 @@
1
  data:
2
  img_size: 224
3
  model:
4
- name: "mobileclip_b"
 
1
  data:
2
  img_size: 224
3
  model:
4
+ name: "mobileclip2_l14"
model.py CHANGED
@@ -1,56 +1,61 @@
1
  import torch
2
  import torch.nn as nn
3
- import mobileclip
4
  import open_clip
5
  from huggingface_hub import hf_hub_download
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class MobileCLIPRanker(nn.Module):
8
- def __init__(self, cfg):
9
  super().__init__()
10
-
11
- model_name = cfg.model.name.lower()
12
- self.model_type = "mobileclip"
13
-
14
- if "l14" in model_name or "l-14" in model_name:
15
- self.model_type = "open_clip"
16
- repo_id = "apple/MobileCLIP2-L-14"
17
- filename = "mobileclip2_l14.pt"
18
- self.backbone_dim = 768
19
-
20
- print(f"Initializing Teacher (L14)...")
21
- ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
22
- model, _, _ = open_clip.create_model_and_transforms('MobileCLIP2-L-14', pretrained=ckpt_path)
23
- self.backbone = model.visual
24
-
25
- else:
26
- repo_id = "apple/MobileCLIP2-B"
27
- filename = "mobileclip2_b.pt"
28
- arch = "mobileclip_b"
29
- self.backbone_dim = 512
30
-
31
- print(f"Initializing Student ({arch})...")
32
- ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
33
- model, _, _ = mobileclip.create_model_and_transforms(arch, pretrained=ckpt_path)
34
- self.backbone = model.image_encoder
35
-
36
- for param in self.backbone.parameters():
37
- param.requires_grad = False
38
-
39
- self.head = nn.Linear(self.backbone_dim, 1)
40
 
41
  def forward(self, x, valid_lens=None):
42
- b, g, c, h, w = x.shape
43
- x_flat = x.view(b * g, c, h, w)
44
-
45
- if self.model_type == "open_clip":
46
- features = self.backbone(x_flat)
47
  else:
48
- features = self.backbone(x_flat)
49
-
50
- if isinstance(features, tuple):
51
- features = features[0]
52
-
53
- features = features.view(b, g, -1)
54
- scores = self.head(features)
55
-
56
- return scores
 
1
  import torch
2
  import torch.nn as nn
 
3
  import open_clip
4
  from huggingface_hub import hf_hub_download
5
 
6
+
7
class RankingHead(nn.Module):
    """Score a feature vector with a 2-layer MLP.

    Preserves the training checkpoint's key layout head.net.{0..3}:
      net.0 Linear(in_dim, hidden_dim), net.1 GELU,
      net.2 Dropout,                    net.3 Linear(hidden_dim, 1)
    """

    def __init__(self, in_dim, hidden_dim=256, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        score = self.net(x)
        return score
26
+
27
+
28
class MobileCLIPRanker(nn.Module):
    """Frozen MobileCLIP2-L14 visual backbone + trainable RankingHead."""

    def __init__(self, backbone_dim=768, head_hidden_dim=256, head_dropout=0.1):
        super().__init__()
        self.backbone_dim = backbone_dim

        print("Initializing MobileCLIP2-L14 backbone...")
        # Download pretrained weights and build the open_clip model;
        # only the visual tower is kept.
        ckpt_path = hf_hub_download(
            repo_id="apple/MobileCLIP2-L-14",
            filename="mobileclip2_l14.pt",
        )
        clip_model, _, _ = open_clip.create_model_and_transforms(
            "MobileCLIP2-L-14", pretrained=ckpt_path
        )
        self.backbone = clip_model.visual

        # Freeze the backbone; only the head is trained.
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.head = RankingHead(backbone_dim, head_hidden_dim, head_dropout)

    def train(self, mode=True):
        """Keep the frozen backbone in eval mode even while training the head."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x, valid_lens=None):
        """Score inputs.

        x: (B, G, C, H, W) raw pixel groups, or already-extracted features
           (any non-5D tensor is passed straight to the head).
        valid_lens: accepted for API compatibility; unused here.
        """
        if x.dim() != 5:
            return self.head(x)
        b, g, c, h, w = x.shape
        flat = x.view(b * g, c, h, w)
        feats = self.backbone(flat).view(b, g, -1)
        return self.head(feats)
 
 
 
 
 
 
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libvips-dev
requirements.txt CHANGED
@@ -4,10 +4,9 @@ fastapi
4
  uvicorn
5
  gradio
6
  pydantic
7
- requests
8
  pyyaml
9
- pillow
10
  huggingface_hub
11
  timm
12
- git+https://github.com/apple/ml-mobileclip.git
13
- open_clip_torch
 
 
4
  uvicorn
5
  gradio
6
  pydantic
 
7
  pyyaml
 
8
  huggingface_hub
9
  timm
10
+ open_clip_torch
11
+ pyvips
12
+ urllib3