manu02 committed on
Commit
69fd2f4
·
verified ·
1 Parent(s): ca68995

Update app.py

Browse files

Fix normalization for SAT images

Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -82,6 +82,18 @@ DEFAULT_DATASET_LABEL = "LVD-1689M" # initial radio
82
  DEFAULT_OVERLAY_ALPHA = 0.55
83
  DEFAULT_SHOW_GRID = True
84
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # ---------- Sample image URLs (dependent on dataset) ----------
86
  SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
87
  # LVD: current ones
@@ -135,17 +147,22 @@ def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Ima
135
  canvas.paste(pil_img, (0, 0))
136
  return canvas, (0, 0, W_pad - W, H_pad - H)
137
 
138
- def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
139
  img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
 
 
 
 
 
140
  transform = transforms.Compose([
141
  transforms.ToTensor(),
142
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
143
- std =[0.229, 0.224, 0.225]),
144
  ])
145
  pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
146
  disp_np = np.array(img_padded, dtype=np.uint8)
147
  return {"pixel_values": pixel_tensor}, disp_np, pad_box
148
 
 
149
  def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
150
  if arr.ndim == 2:
151
  return arr.repeat(ps, 0).repeat(ps, 1)
@@ -198,10 +215,11 @@ def infer_patch_size(model, default: int = 16) -> int:
198
 
199
  # ---------- Per-image state ----------
200
  class PatchImageState:
201
- def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
202
  self.pil = pil_img
203
  self.ps = ps
204
- inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps)
 
205
  self.disp = disp_np
206
  pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
207
  _, _, H, W = pv.shape
@@ -344,9 +362,12 @@ def init_states(
344
  device_str = "cuda" if torch.cuda.is_available() else "cpu"
345
  model = load_model_cached(full_model_id, device_str)
346
  ps = infer_patch_size(model, 16)
 
 
 
347
 
348
- left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
349
- right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
350
 
351
  active_side = 0 if left_state is not None else 1
352
 
 
82
  DEFAULT_OVERLAY_ALPHA = 0.55
83
  DEFAULT_SHOW_GRID = True
84
 
85
+ # ---------- Normalization presets ----------
86
+ NORMALIZE_STATS = {
87
+ "lvd1689m": {
88
+ "mean": [0.485, 0.456, 0.406],
89
+ "std": [0.229, 0.224, 0.225],
90
+ },
91
+ "sat493m": {
92
+ "mean": [0.430, 0.411, 0.296],
93
+ "std": [0.213, 0.156, 0.143],
94
+ },
95
+ }
96
+
97
  # ---------- Sample image URLs (dependent on dataset) ----------
98
  SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
99
  # LVD: current ones
 
147
  canvas.paste(pil_img, (0, 0))
148
  return canvas, (0, 0, W_pad - W, H_pad - H)
149
 
150
+ def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16, dataset_key: str = "lvd1689m"):
151
  img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
152
+
153
+ # Pick stats based on dataset (default to LVD if unknown)
154
+ stats = NORMALIZE_STATS.get(dataset_key, NORMALIZE_STATS["lvd1689m"])
155
+ mean, std = stats["mean"], stats["std"]
156
+
157
  transform = transforms.Compose([
158
  transforms.ToTensor(),
159
+ transforms.Normalize(mean=mean, std=std),
 
160
  ])
161
  pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
162
  disp_np = np.array(img_padded, dtype=np.uint8)
163
  return {"pixel_values": pixel_tensor}, disp_np, pad_box
164
 
165
+
166
  def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
167
  if arr.ndim == 2:
168
  return arr.repeat(ps, 0).repeat(ps, 1)
 
215
 
216
  # ---------- Per-image state ----------
217
  class PatchImageState:
218
+ def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int, dataset_key: str):
219
  self.pil = pil_img
220
  self.ps = ps
221
+ self.dataset_key = dataset_key
222
+ inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps, dataset_key=dataset_key)
223
  self.disp = disp_np
224
  pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
225
  _, _, H, W = pv.shape
 
362
  device_str = "cuda" if torch.cuda.is_available() else "cpu"
363
  model = load_model_cached(full_model_id, device_str)
364
  ps = infer_patch_size(model, 16)
365
+
366
+ # Get dataset_key ("lvd1689m" or "sat493m") from the radio label
367
+ dataset_key = dataset_label_to_key(dataset_label)
368
 
369
+ left_state = PatchImageState(left_img, model, device_str, ps, dataset_key) if left_img is not None else None
370
+ right_state = PatchImageState(right_img, model, device_str, ps, dataset_key) if right_img is not None else None
371
 
372
  active_side = 0 if left_state is not None else 1
373