yusef committed on
Commit
29420ff
·
1 Parent(s): 3242e70

Replace SigLIP with CLIP (stable zero-shot classifier)

Browse files
Files changed (1) hide show
  1. post_processor.py +39 -40
post_processor.py CHANGED
@@ -58,27 +58,27 @@ def load_mobile_sam():
58
  return None
59
 
60
 
61
- def load_siglip():
62
- """ุชุญู…ูŠู„ SigLIP ู„ู„ู€ Zero-Shot material classification."""
63
  global _siglip_model, _siglip_processor
64
  if _siglip_model is not None:
65
  return _siglip_model, _siglip_processor
66
 
67
  try:
68
- from transformers import SiglipProcessor, SiglipModel
69
 
70
- print("๐Ÿ“ฅ ุชุญู…ูŠู„ SigLIP...")
71
- model_id = "google/siglip-base-patch16-224"
72
- _siglip_processor = SiglipProcessor.from_pretrained(model_id)
73
- _siglip_model = SiglipModel.from_pretrained(
74
  model_id,
75
- torch_dtype=torch.float32, # CPU โ†’ float32 ุฏุงูŠู…ุงู‹
76
  ).to(DEVICE).eval()
77
- print("โœ… SigLIP ุฌุงู‡ุฒ!")
78
  return _siglip_model, _siglip_processor
79
 
80
  except Exception as e:
81
- print(f"โš ๏ธ SigLIP ู…ุด ู…ุชุงุญ: {e}")
82
  return None, None
83
 
84
 
@@ -163,60 +163,59 @@ NUM_BUILDING = len(BUILDING_TEXTS)
163
 
164
 
165
  @torch.no_grad()
166
- def is_building_siglip(
167
  image_rgb: np.ndarray,
168
  mask: np.ndarray,
169
  model,
170
  processor,
171
- threshold: float = 0.4,
172
  ) -> bool:
173
  """
174
- ุจูŠุณุชุฎุฏู… SigLIP Zero-Shot ุนุดุงู† ูŠุชุฃูƒุฏ ุฅู† ุงู„ู€ mask ุฏู‡ ูุนู„ุงู‹ ู…ุจู†ู‰.
175
-
176
- Returns True ู„ูˆ ู…ุจู†ู‰ุŒ False ู„ูˆ ู„ุง (ูŠุชุญุฐู).
177
  """
178
  if model is None:
179
- return True # fallback: ุงู‚ุจู„ ูƒู„ ุญุงุฌุฉ ู„ูˆ SigLIP ู…ุด ุดุบุงู„
180
 
181
  try:
182
- # Crop ุงู„ู€ bounding box ู…ู† ุงู„ุตูˆุฑุฉ
183
  ys, xs = np.where(mask)
184
  if len(ys) == 0:
185
  return False
186
- x1, x2 = max(0, xs.min() - 5), min(image_rgb.shape[1], xs.max() + 5)
187
- y1, y2 = max(0, ys.min() - 5), min(image_rgb.shape[0], ys.max() + 5)
188
  crop = image_rgb[y1:y2, x1:x2]
189
-
190
  if crop.size == 0:
191
  return False
192
 
193
  pil_crop = Image.fromarray(crop)
 
 
 
 
 
 
 
 
 
 
194
 
195
- # ุฌู‡ู‘ุฒ ุงู„ู€ inputs
196
  inputs = processor(
197
- text=ALL_TEXTS,
198
  images=[pil_crop],
199
  return_tensors="pt",
200
- padding="max_length",
201
  )
202
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
203
- if DEVICE == "cuda":
204
- inputs["pixel_values"] = inputs["pixel_values"].half()
205
 
206
- # ุงุญุณุจ ุงู„ู€ similarity scores
207
  outputs = model(**inputs)
208
- logits = outputs.logits_per_image[0] # (num_texts,)
209
- probs = torch.softmax(logits, dim=0).cpu().float().numpy()
210
-
211
- # ู…ุฌู…ูˆุน probability ุงู„ู€ building texts
212
- building_score = probs[:NUM_BUILDING].sum()
213
- non_building_score = probs[NUM_BUILDING:].sum()
214
 
215
- return building_score > threshold
216
 
217
  except Exception as e:
218
- print(f"โš ๏ธ SigLIP check error: {e}")
219
- return True # fallback: ุงู‚ุจู„
220
 
221
 
222
  # ============================================================
@@ -288,8 +287,8 @@ def run_v51_pipeline(
288
  list of dicts: [{"mask": np.array, "score": float, "area_m2": float}]
289
  """
290
  # ุชุญู…ูŠู„ ุงู„ู…ูˆุฏูŠู„ุงุช
291
- sam_predictor = load_mobile_sam() if use_sam else None
292
- siglip_model, siglip_proc = load_siglip() if use_siglip else (None, None)
293
 
294
  all_masks = []
295
  all_scores = []
@@ -303,17 +302,17 @@ def run_v51_pipeline(
303
  print(f" SAM: {len(v5_masks)} โ†’ {len(all_masks)} masks")
304
 
305
  # โ”€โ”€ STEP 2: SigLIP Material Check โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
306
- if use_siglip and siglip_model is not None:
307
  filtered_masks = []
308
  filtered_scores = []
309
  removed = 0
310
  for mask, score in zip(all_masks, all_scores):
311
- if is_building_siglip(image_rgb, mask, siglip_model, siglip_proc, siglip_threshold):
312
  filtered_masks.append(mask)
313
  filtered_scores.append(score)
314
  else:
315
  removed += 1
316
- print(f" SigLIP: ุญุฐู {removed} ุบูŠุฑ ู…ุจุงู†ูŠ")
317
  all_masks, all_scores = filtered_masks, filtered_scores
318
 
319
  # โ”€โ”€ STEP 3: Geometric Rules โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
 
58
  return None
59
 
60
 
61
+ def load_clip():
62
+ """ุชุญู…ูŠู„ CLIP ู„ู„ู€ Zero-Shot material classification (ุจุฏูŠู„ SigLIP - ู…ุณุชู‚ุฑ 100%)."""
63
  global _siglip_model, _siglip_processor
64
  if _siglip_model is not None:
65
  return _siglip_model, _siglip_processor
66
 
67
  try:
68
+ from transformers import CLIPProcessor, CLIPModel
69
 
70
+ print("๐Ÿ“ฅ ุชุญู…ูŠู„ CLIP...")
71
+ model_id = "openai/clip-vit-base-patch32"
72
+ _siglip_processor = CLIPProcessor.from_pretrained(model_id)
73
+ _siglip_model = CLIPModel.from_pretrained(
74
  model_id,
75
+ torch_dtype=torch.float32,
76
  ).to(DEVICE).eval()
77
+ print("โœ… CLIP ุฌุงู‡ุฒ!")
78
  return _siglip_model, _siglip_processor
79
 
80
  except Exception as e:
81
+ print(f"โš ๏ธ CLIP ู…ุด ู…ุชุงุญ: {e}")
82
  return None, None
83
 
84
 
 
163
 
164
 
165
  @torch.no_grad()
166
+ def is_building_clip(
167
  image_rgb: np.ndarray,
168
  mask: np.ndarray,
169
  model,
170
  processor,
171
+ threshold: float = 0.5,
172
  ) -> bool:
173
  """
174
+ CLIP Zero-Shot: ูŠุชุญู‚ู‚ ุฅู† ุงู„ู€ mask ุฏู‡ ู…ุจู†ู‰ ูุนู„ุงู‹.
175
+ Returns True ู„ูˆ ู…ุจู†ู‰ุŒ False ู„ูˆ ู„ุง.
 
176
  """
177
  if model is None:
178
+ return True
179
 
180
  try:
 
181
  ys, xs = np.where(mask)
182
  if len(ys) == 0:
183
  return False
184
+ x1 = max(0, xs.min() - 5); x2 = min(image_rgb.shape[1], xs.max() + 5)
185
+ y1 = max(0, ys.min() - 5); y2 = min(image_rgb.shape[0], ys.max() + 5)
186
  crop = image_rgb[y1:y2, x1:x2]
 
187
  if crop.size == 0:
188
  return False
189
 
190
  pil_crop = Image.fromarray(crop)
191
+ building_texts = [
192
+ "a satellite view of a building rooftop",
193
+ "rooftop of a house seen from above",
194
+ ]
195
+ non_building_texts = [
196
+ "farmland or vegetation from satellite",
197
+ "road or parking lot from above",
198
+ "water or swimming pool from satellite",
199
+ ]
200
+ all_texts = building_texts + non_building_texts
201
 
 
202
  inputs = processor(
203
+ text=all_texts,
204
  images=[pil_crop],
205
  return_tensors="pt",
206
+ padding=True,
207
  )
208
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
 
 
209
 
 
210
  outputs = model(**inputs)
211
+ probs = outputs.logits_per_image[0].softmax(dim=0).cpu().float().numpy()
212
+ building_score = probs[:len(building_texts)].sum()
 
 
 
 
213
 
214
+ return float(building_score) > threshold
215
 
216
  except Exception as e:
217
+ print(f"โš ๏ธ CLIP check error: {e}")
218
+ return True
219
 
220
 
221
  # ============================================================
 
287
  list of dicts: [{"mask": np.array, "score": float, "area_m2": float}]
288
  """
289
  # ุชุญู…ูŠู„ ุงู„ู…ูˆุฏูŠู„ุงุช
290
+ sam_predictor = load_mobile_sam() if use_sam else None
291
+ clip_model, clip_proc = load_clip() if use_siglip else (None, None)
292
 
293
  all_masks = []
294
  all_scores = []
 
302
  print(f" SAM: {len(v5_masks)} โ†’ {len(all_masks)} masks")
303
 
304
  # โ”€โ”€ STEP 2: SigLIP Material Check โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
305
+ if use_siglip and clip_model is not None:
306
  filtered_masks = []
307
  filtered_scores = []
308
  removed = 0
309
  for mask, score in zip(all_masks, all_scores):
310
+ if is_building_clip(image_rgb, mask, clip_model, clip_proc):
311
  filtered_masks.append(mask)
312
  filtered_scores.append(score)
313
  else:
314
  removed += 1
315
+ print(f" CLIP: ุญุฐู {removed} ุบูŠุฑ ู…ุจุงู†ูŠ")
316
  all_masks, all_scores = filtered_masks, filtered_scores
317
 
318
  # โ”€โ”€ STEP 3: Geometric Rules โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€