Johnnyyyyy56 commited on
Commit
618d0c1
·
verified ·
1 Parent(s): 935094a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -41
app.py CHANGED
@@ -2,37 +2,61 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # --- 1. Load Custom Model Utilities ---
7
- # NOTE: These imports MUST match the files you copied from the GitHub repo.
8
- # Example imports - adjust these if the model files are deeper in subfolders!
9
- from mmseg.apis import init_segmentor, inference_segmentor # Core MMSeg functions
10
- from mmseg.datasets import build_dataloader, build_dataset # Utilities
 
 
 
11
 
12
 
13
  # --- 2. CONFIGURATION ---
14
- # Define the paths for the files you placed in the repository
15
  WEIGHTS_PATH = "R50_ReLeM.pth"
16
- CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py" # Replace with actual config file from the repo
 
17
 
18
  # --- 3. Model Loading Function ---
19
  @torch.no_grad()
20
  def load_relem_model():
21
  """Initializes the segmentation model and loads the pre-trained weights."""
22
  try:
23
- # 1. Initialize the segmentor using MMSegmentation's utility
24
- # This requires the config file and the checkpoint path
25
  model = init_segmentor(
26
- CONFIG_FILE,
27
- checkpoint=WEIGHTS_PATH,
28
- device='cuda:0' if torch.cuda.is_available() else 'cpu'
29
  )
30
  model.eval()
31
- print("ReLeM Model loaded successfully!")
32
  return model
33
  except Exception as e:
34
- print(f"Error loading model: {e}")
35
- # Return a flag if loading fails
36
  return None
37
 
38
  # Load the model once when the Space starts
@@ -40,61 +64,51 @@ RELEM_MODEL = load_relem_model()
40
 
41
 
42
  # --- 4. Inference Function for Gradio ---
43
- # --- 4. Inference Function for Gradio (REVISED) ---
44
- # --- 4. Inference Function for Gradio (ROBUST LOGGING) ---
45
  def segment_food(input_image: Image.Image):
46
- """Takes a PIL Image and returns a segmentation mask image."""
47
-
48
  if RELEM_MODEL is None:
49
- # If model failed to load at startup, this prints the error
50
- print("RUNTIME ERROR: RELEM_MODEL is None, failing inference.")
51
- return "Error: Model failed to load at startup. Check full build logs."
52
 
53
  try:
54
- # Step 1: Save input image temporarily
55
  temp_path = "/tmp/input_img.png"
56
  input_image.save(temp_path)
57
- print(f"INFO: Saved input image to {temp_path}")
58
 
59
- # Step 2: Run Inference (This is where the memory/config crash occurs)
 
60
  result = inference_segmentor(RELEM_MODEL, temp_path)
61
- print("INFO: Inference completed successfully.")
62
-
63
  # Step 3: Post-process the result into a COLORFUL image
64
  seg_mask_array = result[0]
65
 
66
- # --- ROBUST COLOR MASK CREATION ---
67
- # We use matplotlib to create a visible, colored mask from the raw ID array
68
- import matplotlib.pyplot as plt
69
-
70
- fig, ax = plt.subplots(figsize=(input_image.width / 100, input_image.height / 100)) # Sizing helps prevent memory spikes
71
- ax.imshow(seg_mask_array, cmap='nipy_spectral') # Use a distinct color map
72
  ax.axis('off')
73
 
74
- # Save the figure to a buffer (in memory)
75
- import io
76
  buf = io.BytesIO()
77
  plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
78
- plt.close(fig) # Close the figure immediately to free up memory
79
  buf.seek(0)
80
 
81
  # Return the saved image buffer as a PIL Image
82
- print("INFO: Successfully created color mask.")
83
  return Image.open(buf)
84
 
85
  except Exception as e:
86
- # **THIS CATCHES THE CRASH AND PRINTS IT TO THE LOGS**
87
  print(f"RUNTIME CRASH: Inference failed with error: {e}")
88
- import traceback
89
  traceback.print_exc()
90
- return f"Inference failed at runtime: {e}. Check logs for traceback."
91
 
92
  # --- 5. GRADIO INTERFACE ---
93
  gr.Interface(
94
  fn=segment_food,
95
  inputs=gr.Image(type="pil", label="Upload Food Image"),
96
  outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
97
- title="ReLeM (FoodSeg103) Segmentation Demo",
98
- description="Custom deployment of the ReLeM PyTorch model. **NOTE:** Model loading requires the full code/config structure from the GitHub repo.",
99
  allow_flagging="never"
100
  ).launch()
 
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
+ import subprocess
6
+ import sys
7
+ import io
8
+ import matplotlib.pyplot as plt
9
+ import traceback
10
+
11
+ # --- 0. PATCH IN PLACE: Temporarily fix the mmseg version check ---
12
+ # This code block forces the required mmcv-full to be installed correctly.
13
+
14
+ try:
15
+ print("INFO: Attempting to install pre-built mmcv-full...")
16
+ # This installs the mmcv-full wheel pre-built for PyTorch 1.13, which includes the necessary _ext modules.
17
+ subprocess.check_call([
18
+ sys.executable, '-m', 'pip', 'install',
19
+ 'mmcv-full==1.7.1',
20
+ '-f', 'https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html'
21
+ ])
22
+ print("INFO: Successfully installed pre-built mmcv-full.")
23
+ except subprocess.CalledProcessError as e:
24
+ print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")
25
+ # Exit if critical dependency fails to install
26
+ sys.exit(1)
27
 
28
  # --- 1. Load Custom Model Utilities ---
29
+ # These imports rely on the code being copied correctly and the mmcv patch working.
30
+ try:
31
+ # These imports should now work because mmcv is installed correctly
32
+ from mmseg.apis import init_segmentor, inference_segmentor
33
+ except Exception as e:
34
+ print(f"FATAL ERROR: Failed to import mmseg utilities: {e}")
35
+ sys.exit(1)
36
 
37
 
38
  # --- 2. CONFIGURATION ---
39
+ # Ensure these paths match your file names and structure
40
  WEIGHTS_PATH = "R50_ReLeM.pth"
41
+ CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
42
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
43
 
44
  # --- 3. Model Loading Function ---
45
  @torch.no_grad()
46
  def load_relem_model():
47
  """Initializes the segmentation model and loads the pre-trained weights."""
48
  try:
 
 
49
  model = init_segmentor(
50
+ CONFIG_FILE,
51
+ checkpoint=WEIGHTS_PATH,
52
+ device=DEVICE
53
  )
54
  model.eval()
55
+ print(f"ReLeM Model loaded successfully onto {DEVICE}!")
56
  return model
57
  except Exception as e:
58
+ print(f"CRITICAL ERROR: Model failed to load weights or config: {e}")
59
+ traceback.print_exc()
60
  return None
61
 
62
  # Load the model once when the Space starts
 
64
 
65
 
66
  # --- 4. Inference Function for Gradio ---
 
 
67
  def segment_food(input_image: Image.Image):
68
+ """Takes a PIL Image, runs inference, and returns a colorful segmentation mask."""
69
+
70
  if RELEM_MODEL is None:
71
+ return "Error: Model failed to load at startup. Check build logs."
 
 
72
 
73
  try:
74
+ # Step 1: Save input image temporarily (Required by mmseg's inference pipeline)
75
  temp_path = "/tmp/input_img.png"
76
  input_image.save(temp_path)
 
77
 
78
+ # Step 2: Run Inference (Produces the raw class ID map)
79
+ # **This is the point where an OOM (Out of Memory) crash usually happens**
80
  result = inference_segmentor(RELEM_MODEL, temp_path)
81
+
 
82
  # Step 3: Post-process the result into a COLORFUL image
83
  seg_mask_array = result[0]
84
 
85
+ # --- MATPLOTLIB VISUALIZATION (Robust Color Mask) ---
86
+ # Create a new figure to plot the mask with distinct colors
87
+ fig, ax = plt.subplots(figsize=(8, 8)) # Use a moderate size
88
+ ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest') # Use colorful colormap
 
 
89
  ax.axis('off')
90
 
91
+ # Save the figure to an in-memory buffer
 
92
  buf = io.BytesIO()
93
  plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
94
+ plt.close(fig) # Free memory
95
  buf.seek(0)
96
 
97
  # Return the saved image buffer as a PIL Image
 
98
  return Image.open(buf)
99
 
100
  except Exception as e:
101
+ # Catches memory errors or other runtime crashes
102
  print(f"RUNTIME CRASH: Inference failed with error: {e}")
 
103
  traceback.print_exc()
104
+ return f"Inference failed at runtime. Error: {e}. Try a smaller image."
105
 
106
  # --- 5. GRADIO INTERFACE ---
107
  gr.Interface(
108
  fn=segment_food,
109
  inputs=gr.Image(type="pil", label="Upload Food Image"),
110
  outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
111
+ title="ReLeM (FoodSeg103) Deployment Final Attempt",
112
+ description="Custom deployment of the ReLeM PyTorch model. Check logs for deployment status.",
113
  allow_flagging="never"
114
  ).launch()