Update app.py
Browse filesFix normailzation for SAT images
app.py
CHANGED
|
@@ -82,6 +82,18 @@ DEFAULT_DATASET_LABEL = "LVD-1689M" # initial radio
|
|
| 82 |
DEFAULT_OVERLAY_ALPHA = 0.55
|
| 83 |
DEFAULT_SHOW_GRID = True
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# ---------- Sample image URLs (dependent on dataset) ----------
|
| 86 |
SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
|
| 87 |
# LVD: current ones
|
|
@@ -135,17 +147,22 @@ def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Ima
|
|
| 135 |
canvas.paste(pil_img, (0, 0))
|
| 136 |
return canvas, (0, 0, W_pad - W, H_pad - H)
|
| 137 |
|
| 138 |
-
def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
|
| 139 |
img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
transform = transforms.Compose([
|
| 141 |
transforms.ToTensor(),
|
| 142 |
-
transforms.Normalize(mean=
|
| 143 |
-
std =[0.229, 0.224, 0.225]),
|
| 144 |
])
|
| 145 |
pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
|
| 146 |
disp_np = np.array(img_padded, dtype=np.uint8)
|
| 147 |
return {"pixel_values": pixel_tensor}, disp_np, pad_box
|
| 148 |
|
|
|
|
| 149 |
def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
|
| 150 |
if arr.ndim == 2:
|
| 151 |
return arr.repeat(ps, 0).repeat(ps, 1)
|
|
@@ -198,10 +215,11 @@ def infer_patch_size(model, default: int = 16) -> int:
|
|
| 198 |
|
| 199 |
# ---------- Per-image state ----------
|
| 200 |
class PatchImageState:
|
| 201 |
-
def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
|
| 202 |
self.pil = pil_img
|
| 203 |
self.ps = ps
|
| 204 |
-
|
|
|
|
| 205 |
self.disp = disp_np
|
| 206 |
pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
|
| 207 |
_, _, H, W = pv.shape
|
|
@@ -344,9 +362,12 @@ def init_states(
|
|
| 344 |
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 345 |
model = load_model_cached(full_model_id, device_str)
|
| 346 |
ps = infer_patch_size(model, 16)
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
-
left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
|
| 349 |
-
right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
|
| 350 |
|
| 351 |
active_side = 0 if left_state is not None else 1
|
| 352 |
|
|
|
|
| 82 |
DEFAULT_OVERLAY_ALPHA = 0.55
|
| 83 |
DEFAULT_SHOW_GRID = True
|
| 84 |
|
| 85 |
+
# ---------- Normalization presets ----------
|
| 86 |
+
NORMALIZE_STATS = {
|
| 87 |
+
"lvd1689m": {
|
| 88 |
+
"mean": [0.485, 0.456, 0.406],
|
| 89 |
+
"std": [0.229, 0.224, 0.225],
|
| 90 |
+
},
|
| 91 |
+
"sat493m": {
|
| 92 |
+
"mean": [0.430, 0.411, 0.296],
|
| 93 |
+
"std": [0.213, 0.156, 0.143],
|
| 94 |
+
},
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
# ---------- Sample image URLs (dependent on dataset) ----------
|
| 98 |
SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
|
| 99 |
# LVD: current ones
|
|
|
|
| 147 |
canvas.paste(pil_img, (0, 0))
|
| 148 |
return canvas, (0, 0, W_pad - W, H_pad - H)
|
| 149 |
|
| 150 |
+
def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16, dataset_key: str = "lvd1689m"):
|
| 151 |
img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
|
| 152 |
+
|
| 153 |
+
# Pick stats based on dataset (default to LVD if unknown)
|
| 154 |
+
stats = NORMALIZE_STATS.get(dataset_key, NORMALIZE_STATS["lvd1689m"])
|
| 155 |
+
mean, std = stats["mean"], stats["std"]
|
| 156 |
+
|
| 157 |
transform = transforms.Compose([
|
| 158 |
transforms.ToTensor(),
|
| 159 |
+
transforms.Normalize(mean=mean, std=std),
|
|
|
|
| 160 |
])
|
| 161 |
pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
|
| 162 |
disp_np = np.array(img_padded, dtype=np.uint8)
|
| 163 |
return {"pixel_values": pixel_tensor}, disp_np, pad_box
|
| 164 |
|
| 165 |
+
|
| 166 |
def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
|
| 167 |
if arr.ndim == 2:
|
| 168 |
return arr.repeat(ps, 0).repeat(ps, 1)
|
|
|
|
| 215 |
|
| 216 |
# ---------- Per-image state ----------
|
| 217 |
class PatchImageState:
|
| 218 |
+
def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int, dataset_key: str):
|
| 219 |
self.pil = pil_img
|
| 220 |
self.ps = ps
|
| 221 |
+
self.dataset_key = dataset_key
|
| 222 |
+
inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps, dataset_key=dataset_key)
|
| 223 |
self.disp = disp_np
|
| 224 |
pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
|
| 225 |
_, _, H, W = pv.shape
|
|
|
|
| 362 |
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 363 |
model = load_model_cached(full_model_id, device_str)
|
| 364 |
ps = infer_patch_size(model, 16)
|
| 365 |
+
|
| 366 |
+
# Get dataset_key ("lvd1689m" or "sat493m") from the radio label
|
| 367 |
+
dataset_key = dataset_label_to_key(dataset_label)
|
| 368 |
|
| 369 |
+
left_state = PatchImageState(left_img, model, device_str, ps, dataset_key) if left_img is not None else None
|
| 370 |
+
right_state = PatchImageState(right_img, model, device_str, ps, dataset_key) if right_img is not None else None
|
| 371 |
|
| 372 |
active_side = 0 if left_state is not None else 1
|
| 373 |
|