Thomaslam1202 commited on
Commit
0bd8e51
Β·
verified Β·
1 Parent(s): 9c72c46

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -13
app.py CHANGED
@@ -11,8 +11,6 @@ import matplotlib.patches as mpatches
11
  import io
12
 
13
  # ── Config (from your training code) ──────────────────────────────
14
- MODEL_REPO = "Thomaslam1202/Cityscapes_Segmentation_Model"
15
- MODEL_FILE = "best.pth"
16
  IMG_SIZE = (512, 512) # your cfg.img_size
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
 
@@ -55,27 +53,20 @@ def mask_to_colour(mask: np.ndarray) -> np.ndarray:
55
 
56
  # ── Load model once at startup ─────────────────────────────────────
57
  def load_model():
58
- print(f"Loading on: {DEVICE}")
59
  model = SegformerForSemanticSegmentation.from_pretrained(
60
- "nvidia/mit-b2", # same backbone as your training
61
- num_labels=19, # cfg.num_class
62
  id2label=id_to_label,
63
  label2id=label_to_id,
64
- ignore_mismatched_sizes=True, # same flag as your training
65
  )
66
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
67
- ckpt = torch.load(model_path, map_location=DEVICE)
68
-
69
- # ← exact same _orig_mod stripping logic from your resume block
70
  state_dict = ckpt["state_dict"]
71
  if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
72
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
73
-
74
  model.load_state_dict(state_dict)
75
  return model.to(DEVICE).eval()
76
 
77
- model = load_model()
78
-
79
  # ── Inference (mirrors your predict() function) ────────────────────
80
  def run_inference(pil_image: Image.Image):
81
  img_tensor = preprocess(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)
 
11
  import io
12
 
13
  # ── Config (from your training code) ──────────────────────────────
 
 
14
  IMG_SIZE = (512, 512) # your cfg.img_size
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
 
 
53
 
54
  # ── Load model once at startup ─────────────────────────────────────
55
  def load_model():
 
56
  model = SegformerForSemanticSegmentation.from_pretrained(
57
+ "nvidia/mit-b2",
58
+ num_labels=19,
59
  id2label=id_to_label,
60
  label2id=label_to_id,
61
+ ignore_mismatched_sizes=True,
62
  )
63
+ ckpt = torch.load("best.pth", map_location=DEVICE)
 
 
 
64
  state_dict = ckpt["state_dict"]
65
  if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
66
  state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
 
67
  model.load_state_dict(state_dict)
68
  return model.to(DEVICE).eval()
69
 
 
 
70
  # ── Inference (mirrors your predict() function) ────────────────────
71
  def run_inference(pil_image: Image.Image):
72
  img_tensor = preprocess(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)