Johnnyyyyy56 commited on
Commit
264c98b
·
verified ·
1 Parent(s): 618d0c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -8,12 +8,17 @@ import io
8
  import matplotlib.pyplot as plt
9
  import traceback
10
 
11
- # --- 0. PATCH IN PLACE: Temporarily fix the mmseg version check ---
12
- # This code block forces the required mmcv-full to be installed correctly.
 
 
 
 
13
 
 
14
  try:
15
  print("INFO: Attempting to install pre-built mmcv-full...")
16
- # This installs the mmcv-full wheel pre-built for PyTorch 1.13, which includes the necessary _ext modules.
17
  subprocess.check_call([
18
  sys.executable, '-m', 'pip', 'install',
19
  'mmcv-full==1.7.1',
@@ -22,21 +27,21 @@ try:
22
  print("INFO: Successfully installed pre-built mmcv-full.")
23
  except subprocess.CalledProcessError as e:
24
  print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")
25
- # Exit if critical dependency fails to install
26
- sys.exit(1)
 
27
 
28
- # --- 1. Load Custom Model Utilities ---
29
- # These imports rely on the code being copied correctly and the mmcv patch working.
30
  try:
31
- # These imports should now work because mmcv is installed correctly
32
  from mmseg.apis import init_segmentor, inference_segmentor
33
  except Exception as e:
34
  print(f"FATAL ERROR: Failed to import mmseg utilities: {e}")
35
- sys.exit(1)
 
 
36
 
37
 
38
  # --- 2. CONFIGURATION ---
39
- # Ensure these paths match your file names and structure
40
  WEIGHTS_PATH = "R50_ReLeM.pth"
41
  CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
42
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
@@ -45,6 +50,9 @@ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
45
  @torch.no_grad()
46
  def load_relem_model():
47
  """Initializes the segmentation model and loads the pre-trained weights."""
 
 
 
48
  try:
49
  model = init_segmentor(
50
  CONFIG_FILE,
@@ -59,7 +67,6 @@ def load_relem_model():
59
  traceback.print_exc()
60
  return None
61
 
62
- # Load the model once when the Space starts
63
  RELEM_MODEL = load_relem_model()
64
 
65
 
@@ -68,7 +75,7 @@ def segment_food(input_image: Image.Image):
68
  """Takes a PIL Image, runs inference, and returns a colorful segmentation mask."""
69
 
70
  if RELEM_MODEL is None:
71
- return "Error: Model failed to load at startup. Check build logs."
72
 
73
  try:
74
  # Step 1: Save input image temporarily (Required by mmseg's inference pipeline)
@@ -76,29 +83,25 @@ def segment_food(input_image: Image.Image):
76
  input_image.save(temp_path)
77
 
78
  # Step 2: Run Inference (Produces the raw class ID map)
79
- # **This is the point where an OOM (Out of Memory) crash usually happens**
80
  result = inference_segmentor(RELEM_MODEL, temp_path)
81
 
82
  # Step 3: Post-process the result into a COLORFUL image
83
  seg_mask_array = result[0]
84
 
85
  # --- MATPLOTLIB VISUALIZATION (Robust Color Mask) ---
86
- # Create a new figure to plot the mask with distinct colors
87
- fig, ax = plt.subplots(figsize=(8, 8)) # Use a moderate size
88
- ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest') # Use colorful colormap
89
  ax.axis('off')
90
 
91
  # Save the figure to an in-memory buffer
92
  buf = io.BytesIO()
93
  plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
94
- plt.close(fig) # Free memory
95
  buf.seek(0)
96
 
97
- # Return the saved image buffer as a PIL Image
98
  return Image.open(buf)
99
 
100
  except Exception as e:
101
- # Catches memory errors or other runtime crashes
102
  print(f"RUNTIME CRASH: Inference failed with error: {e}")
103
  traceback.print_exc()
104
  return f"Inference failed at runtime. Error: {e}. Try a smaller image."
 
8
  import matplotlib.pyplot as plt
9
  import traceback
10
 
11
+ # --- CRITICAL PATCH: Fix for 'container_abcs' not found in torch._six ---
12
+ # This makes older code compatible with PyTorch 1.13.1 by providing the correct import.
13
+ try:
14
+ from torch._six import container_abcs
15
+ except ImportError:
16
+ import collections.abc as container_abcs
17
 
18
+ # --- 0. FORCE INSTALL: Install pre-built mmcv-full for _ext modules ---
19
  try:
20
  print("INFO: Attempting to install pre-built mmcv-full...")
21
+ # This installs the mmcv-full wheel pre-built for PyTorch 1.13, which includes the necessary compiled _ext modules.
22
  subprocess.check_call([
23
  sys.executable, '-m', 'pip', 'install',
24
  'mmcv-full==1.7.1',
 
27
  print("INFO: Successfully installed pre-built mmcv-full.")
28
  except subprocess.CalledProcessError as e:
29
  print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")
30
+ # We allow the code to continue execution but the model will likely fail to load later
31
+ pass # Continue execution, but model will likely fail to load
32
+
33
 
34
+ # --- 1. Load Custom Model Utilities (Must come after mmcv is installed) ---
 
35
  try:
 
36
  from mmseg.apis import init_segmentor, inference_segmentor
37
  except Exception as e:
38
  print(f"FATAL ERROR: Failed to import mmseg utilities: {e}")
39
+ # Returning None here will trigger the "Error: Model failed to load" message in the app.
40
+ init_segmentor = None
41
+ inference_segmentor = None
42
 
43
 
44
  # --- 2. CONFIGURATION ---
 
45
  WEIGHTS_PATH = "R50_ReLeM.pth"
46
  CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
47
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
50
  @torch.no_grad()
51
  def load_relem_model():
52
  """Initializes the segmentation model and loads the pre-trained weights."""
53
+ if init_segmentor is None:
54
+ return None # Skip if imports failed
55
+
56
  try:
57
  model = init_segmentor(
58
  CONFIG_FILE,
 
67
  traceback.print_exc()
68
  return None
69
 
 
70
  RELEM_MODEL = load_relem_model()
71
 
72
 
 
75
  """Takes a PIL Image, runs inference, and returns a colorful segmentation mask."""
76
 
77
  if RELEM_MODEL is None:
78
+ return "Error: Model failed to load at startup. Check build logs for reason."
79
 
80
  try:
81
  # Step 1: Save input image temporarily (Required by mmseg's inference pipeline)
 
83
  input_image.save(temp_path)
84
 
85
  # Step 2: Run Inference (Produces the raw class ID map)
 
86
  result = inference_segmentor(RELEM_MODEL, temp_path)
87
 
88
  # Step 3: Post-process the result into a COLORFUL image
89
  seg_mask_array = result[0]
90
 
91
  # --- MATPLOTLIB VISUALIZATION (Robust Color Mask) ---
92
+ fig, ax = plt.subplots(figsize=(8, 8))
93
+ ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest')
 
94
  ax.axis('off')
95
 
96
  # Save the figure to an in-memory buffer
97
  buf = io.BytesIO()
98
  plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
99
+ plt.close(fig)
100
  buf.seek(0)
101
 
 
102
  return Image.open(buf)
103
 
104
  except Exception as e:
 
105
  print(f"RUNTIME CRASH: Inference failed with error: {e}")
106
  traceback.print_exc()
107
  return f"Inference failed at runtime. Error: {e}. Try a smaller image."