Upload app.py
Browse files
app.py
CHANGED
|
@@ -11,8 +11,6 @@ import matplotlib.patches as mpatches
|
|
| 11 |
import io
|
| 12 |
|
| 13 |
# ββ Config (from your training code) ββββββββββββββββββββββββββββββ
|
| 14 |
-
MODEL_REPO = "Thomaslam1202/Cityscapes_Segmentation_Model"
|
| 15 |
-
MODEL_FILE = "best.pth"
|
| 16 |
IMG_SIZE = (512, 512) # your cfg.img_size
|
| 17 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
|
|
@@ -55,27 +53,20 @@ def mask_to_colour(mask: np.ndarray) -> np.ndarray:
|
|
| 55 |
|
| 56 |
# ββ Load model once at startup βββββββββββββββββββββββββββββββββββββ
|
| 57 |
def load_model():
|
| 58 |
-
print(f"Loading on: {DEVICE}")
|
| 59 |
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 60 |
-
"nvidia/mit-b2",
|
| 61 |
-
num_labels=19,
|
| 62 |
id2label=id_to_label,
|
| 63 |
label2id=label_to_id,
|
| 64 |
-
ignore_mismatched_sizes=True,
|
| 65 |
)
|
| 66 |
-
|
| 67 |
-
ckpt = torch.load(model_path, map_location=DEVICE)
|
| 68 |
-
|
| 69 |
-
# β exact same _orig_mod stripping logic from your resume block
|
| 70 |
state_dict = ckpt["state_dict"]
|
| 71 |
if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
|
| 72 |
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
|
| 73 |
-
|
| 74 |
model.load_state_dict(state_dict)
|
| 75 |
return model.to(DEVICE).eval()
|
| 76 |
|
| 77 |
-
model = load_model()
|
| 78 |
-
|
| 79 |
# ββ Inference (mirrors your predict() function) ββββββββββββββββββββ
|
| 80 |
def run_inference(pil_image: Image.Image):
|
| 81 |
img_tensor = preprocess(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)
|
|
|
|
| 11 |
import io
|
| 12 |
|
| 13 |
# ββ Config (from your training code) ββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
| 14 |
IMG_SIZE = (512, 512) # your cfg.img_size
|
| 15 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
|
|
|
| 53 |
|
| 54 |
# ββ Load model once at startup βββββββββββββββββββββββββββββββββββββ
|
| 55 |
def load_model():
|
|
|
|
| 56 |
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 57 |
+
"nvidia/mit-b2",
|
| 58 |
+
num_labels=19,
|
| 59 |
id2label=id_to_label,
|
| 60 |
label2id=label_to_id,
|
| 61 |
+
ignore_mismatched_sizes=True,
|
| 62 |
)
|
| 63 |
+
ckpt = torch.load("best.pth", map_location=DEVICE)
|
|
|
|
|
|
|
|
|
|
| 64 |
state_dict = ckpt["state_dict"]
|
| 65 |
if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
|
| 66 |
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
|
|
|
|
| 67 |
model.load_state_dict(state_dict)
|
| 68 |
return model.to(DEVICE).eval()
|
| 69 |
|
|
|
|
|
|
|
| 70 |
# ββ Inference (mirrors your predict() function) ββββββββββββββββββββ
|
| 71 |
def run_inference(pil_image: Image.Image):
|
| 72 |
img_tensor = preprocess(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)
|