V4ldeLund commited on
Commit
ca079c9
·
verified ·
1 Parent(s): 722c3cc

Update code/comments from local workspace

Browse files
Files changed (2) hide show
  1. app.py +4 -6
  2. segmenters/sam3.py +15 -16
app.py CHANGED
@@ -6,7 +6,6 @@ from PIL import Image
6
  import torch
7
  import spaces
8
 
9
- # Disable matplotlib visualizations inside the backend call (Spaces are headless)
10
  import utils.visualize as vis
11
  vis.visualize_segmentation = lambda *args, **kwargs: None # type: ignore
12
 
@@ -14,14 +13,13 @@ from models.model_bank_knn import PatchKNNDetector
14
  from backbones import get_backbone
15
  from segmenters import SAM3Segmenter
16
 
17
- # ZeroGPU: avoid initializing CUDA at import time. Keep everything on CPU until the
18
- # GPU-decorated inference runs and a slice is attached.
19
  DEFAULT_DEVICE = "cpu"
20
 
21
 
22
  @functools.lru_cache(maxsize=1)
23
  def load_backbone(name: str = "dinov3_small"):
24
- # Keep on CPU; move to GPU inside infer when available
25
  return get_backbone(name).to(DEFAULT_DEVICE).eval()
26
 
27
 
@@ -57,7 +55,7 @@ def _make_overlay(rgb_image: np.ndarray, anomaly_map: np.ndarray) -> Image.Image
57
  return Image.fromarray(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))
58
 
59
 
60
- @spaces.GPU # When running on ZeroGPU, this grants a short-lived GPU slice for the call.
61
  def infer(ref_files, test_file, use_sam3, sam_prompt):
62
  if not ref_files:
63
  raise gr.Error("Upload at least one reference image.")
@@ -103,7 +101,7 @@ def build_demo():
103
  with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
104
  gr.Markdown(
105
  "# Patch KNN Anomaly Detection\n"
106
- "Upload reference (normal) images, one test image, and optionally run SAM3 to focus on a specific object."
107
  )
108
 
109
  with gr.Row():
 
6
  import torch
7
  import spaces
8
 
 
9
  import utils.visualize as vis
10
  vis.visualize_segmentation = lambda *args, **kwargs: None # type: ignore
11
 
 
13
  from backbones import get_backbone
14
  from segmenters import SAM3Segmenter
15
 
16
+
 
17
  DEFAULT_DEVICE = "cpu"
18
 
19
 
20
  @functools.lru_cache(maxsize=1)
21
  def load_backbone(name: str = "dinov3_small"):
22
+ # Keep on CPU will move to gpu if available
23
  return get_backbone(name).to(DEFAULT_DEVICE).eval()
24
 
25
 
 
55
  return Image.fromarray(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))
56
 
57
 
58
+ @spaces.GPU
59
  def infer(ref_files, test_file, use_sam3, sam_prompt):
60
  if not ref_files:
61
  raise gr.Error("Upload at least one reference image.")
 
101
  with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
102
  gr.Markdown(
103
  "# Patch KNN Anomaly Detection\n"
104
+ "Upload reference (good) images, one test image, and optionally run SAM3 to segment the specific foreground object."
105
  )
106
 
107
  with gr.Row():
segmenters/sam3.py CHANGED
@@ -3,8 +3,8 @@ from __future__ import annotations
3
  import numpy as np
4
  import torch
5
  from PIL import Image
6
- import os
7
- from transformers import Sam3Processor, Sam3Model
8
 
9
  from segmenters import BaseSegmenter
10
 
@@ -41,20 +41,19 @@ class SAM3Segmenter(BaseSegmenter):
41
  self.score_threshold = score_threshold
42
  self.mask_threshold = mask_threshold
43
 
44
- # Loading model model + defining processor
45
- token = os.getenv("HF_TOKEN")
46
- # facebook/sam3 is a gated model; token required on Spaces without pre-approval
47
- self.model = Sam3Model.from_pretrained(
48
- model_name,
49
- token=token,
50
- trust_remote_code=True,
51
- ).to(self.device)
52
- self.model.eval()
53
- self.processor = Sam3Processor.from_pretrained(
54
- model_name,
55
- token=token,
56
- trust_remote_code=True,
57
- )
58
 
59
  def get_object_mask(self, image: np.ndarray) -> np.ndarray:
60
  """
 
3
  import numpy as np
4
  import torch
5
  from PIL import Image
6
+ import os
7
+ from transformers import Sam3Processor, Sam3Model
8
 
9
  from segmenters import BaseSegmenter
10
 
 
41
  self.score_threshold = score_threshold
42
  self.mask_threshold = mask_threshold
43
 
44
+ # Loading model + defining processor
45
+ token = os.getenv("HF_TOKEN")
46
+ self.model = Sam3Model.from_pretrained(
47
+ model_name,
48
+ token=token,
49
+ trust_remote_code=True,
50
+ ).to(self.device)
51
+ self.model.eval()
52
+ self.processor = Sam3Processor.from_pretrained(
53
+ model_name,
54
+ token=token,
55
+ trust_remote_code=True,
56
+ )
 
57
 
58
  def get_object_mask(self, image: np.ndarray) -> np.ndarray:
59
  """