Johnnyyyyy56 commited on
Commit
deddc3a
·
verified ·
1 Parent(s): 97d3c71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -45,43 +45,53 @@ RELEM_MODEL = load_relem_model()
45
 
46
  # --- 4. Inference Function for Gradio ---
47
  # --- 4. Inference Function for Gradio (REVISED) ---
 
48
  def segment_food(input_image: Image.Image):
49
  """Takes a PIL Image and returns a segmentation mask image."""
50
 
51
  if RELEM_MODEL is None:
52
- # If model failed to load, return the error message
53
- return "Error: Model failed to load. Check logs for details."
 
54
 
55
  try:
56
  # Step 1: Save input image temporarily
57
  temp_path = "/tmp/input_img.png"
58
  input_image.save(temp_path)
 
59
 
60
- # Step 2: Run Inference (Produces the class ID map)
61
  result = inference_segmentor(RELEM_MODEL, temp_path)
62
-
63
- # Step 3: Use the MMSegmentation utility to visualize the result
64
- # NOTE: This requires the 'show_result_pyplot' utility and the model's palette.
65
- # Since we don't have the utility, we will use a simplified NumPy approach that is less prone to memory errors:
66
 
 
67
  seg_mask_array = result[0]
68
 
69
- # --- SIMPLE NUMPY VISUALIZATION (Requires model to define a palette) ---
70
- # MMSegmentation models usually contain a palette definition.
71
- # If the model has a PALETTE property, use it. Otherwise, we convert to a colored PIL Image.
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # Assuming RELEM_MODEL has a palette defined (common in MMSegmentation models)
74
- if hasattr(RELEM_MODEL, 'PALETTE'):
75
- # This is a placeholder as the PALETTE logic is complex.
76
- # We'll just return the raw NumPy array which Gradio might handle better:
77
- return Image.fromarray(seg_mask_array.astype(np.uint8))
78
- else:
79
- # Fallback to the simple (but likely black/grayscale) conversion
80
- return Image.fromarray(seg_mask_array.astype(np.uint8)).convert("L")
81
 
82
  except Exception as e:
83
- # This will catch any error during inference and tell you what went wrong.
84
- return f"Inference failed at runtime: {e}"
 
 
 
85
 
86
  # --- 5. GRADIO INTERFACE ---
87
  gr.Interface(
 
45
 
46
  # --- 4. Inference Function for Gradio ---
47
  # --- 4. Inference Function for Gradio (REVISED) ---
48
+ # --- 4. Inference Function for Gradio (ROBUST LOGGING) ---
49
  def segment_food(input_image: Image.Image):
50
  """Takes a PIL Image and returns a segmentation mask image."""
51
 
52
  if RELEM_MODEL is None:
53
+ # If model failed to load at startup, this prints the error
54
+ print("RUNTIME ERROR: RELEM_MODEL is None, failing inference.")
55
+ return "Error: Model failed to load at startup. Check full build logs."
56
 
57
  try:
58
  # Step 1: Save input image temporarily
59
  temp_path = "/tmp/input_img.png"
60
  input_image.save(temp_path)
61
+ print(f"INFO: Saved input image to {temp_path}")
62
 
63
+ # Step 2: Run Inference (This is where the memory/config crash occurs)
64
  result = inference_segmentor(RELEM_MODEL, temp_path)
65
+ print("INFO: Inference completed successfully.")
 
 
 
66
 
67
+ # Step 3: Post-process the result into a COLORFUL image
68
  seg_mask_array = result[0]
69
 
70
+ # --- ROBUST COLOR MASK CREATION ---
71
+ # We use matplotlib to create a visible, colored mask from the raw ID array
72
+ import matplotlib.pyplot as plt
73
+
74
+ fig, ax = plt.subplots(figsize=(input_image.width / 100, input_image.height / 100)) # Sizing helps prevent memory spikes
75
+ ax.imshow(seg_mask_array, cmap='nipy_spectral') # Use a distinct color map
76
+ ax.axis('off')
77
+
78
+ # Save the figure to a buffer (in memory)
79
+ import io
80
+ buf = io.BytesIO()
81
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
82
+ plt.close(fig) # Close the figure immediately to free up memory
83
+ buf.seek(0)
84
 
85
+ # Return the saved image buffer as a PIL Image
86
+ print("INFO: Successfully created color mask.")
87
+ return Image.open(buf)
 
 
 
 
 
88
 
89
  except Exception as e:
90
+ # **THIS CATCHES THE CRASH AND PRINTS IT TO THE LOGS**
91
+ print(f"RUNTIME CRASH: Inference failed with error: {e}")
92
+ import traceback
93
+ traceback.print_exc()
94
+ return f"Inference failed at runtime: {e}. Check logs for traceback."
95
 
96
  # --- 5. GRADIO INTERFACE ---
97
  gr.Interface(