Johnnyyyyy56 commited on
Commit
97d3c71
·
verified ·
1 Parent(s): 22d0768

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -44,34 +44,44 @@ RELEM_MODEL = load_relem_model()
44
 
45
 
46
  # --- 4. Inference Function for Gradio ---
 
47
  def segment_food(input_image: Image.Image):
48
  """Takes a PIL Image and returns a segmentation mask image."""
49
 
50
  if RELEM_MODEL is None:
 
51
  return "Error: Model failed to load. Check logs for details."
52
 
53
  try:
54
- # Use MMSegmentation's inference pipeline
55
- # The input is usually a filepath, so we need to save and then load
56
-
57
- # 1. Save input image temporarily
58
  temp_path = "/tmp/input_img.png"
59
  input_image.save(temp_path)
60
 
61
- # 2. Run Inference
62
  result = inference_segmentor(RELEM_MODEL, temp_path)
63
 
64
- # 3. Post-process the result (usually a numpy array) into a color mask image
65
- # The result is a segmentation map (array of class IDs).
66
- # We use a simple utility to convert the ID map to a visible color mask.
67
- seg_mask_array = result[0]
68
- color_mask = Image.fromarray(seg_mask_array.astype(np.uint8)).convert("L")
69
- # NOTE: Full color mapping requires the class labels/palette, which you must also copy from the repo.
 
 
 
70
 
71
- return color_mask
 
 
 
 
 
 
 
72
 
73
  except Exception as e:
74
- return f"Inference failed: {e}"
 
75
 
76
  # --- 5. GRADIO INTERFACE ---
77
  gr.Interface(
 
44
 
45
 
46
  # --- 4. Inference Function for Gradio ---
47
+ # --- 4. Inference Function for Gradio (REVISED) ---
48
  def segment_food(input_image: Image.Image):
49
  """Takes a PIL Image and returns a segmentation mask image."""
50
 
51
  if RELEM_MODEL is None:
52
+ # If model failed to load, return the error message
53
  return "Error: Model failed to load. Check logs for details."
54
 
55
  try:
56
+ # Step 1: Save input image temporarily
 
 
 
57
  temp_path = "/tmp/input_img.png"
58
  input_image.save(temp_path)
59
 
60
+ # Step 2: Run Inference (Produces the class ID map)
61
  result = inference_segmentor(RELEM_MODEL, temp_path)
62
 
63
+ # Step 3: Use the MMSegmentation utility to visualize the result
64
+ # NOTE: This requires the 'show_result_pyplot' utility and the model's palette.
65
+ # Since we don't have the utility, we will use a simplified NumPy approach that is less prone to memory errors:
66
+
67
+ seg_mask_array = result[0]
68
+
69
+ # --- SIMPLE NUMPY VISUALIZATION (Requires model to define a palette) ---
70
+ # MMSegmentation models usually contain a palette definition.
71
+ # If the model has a PALETTE property, use it. Otherwise, we convert to a colored PIL Image.
72
 
73
+ # Assuming RELEM_MODEL has a palette defined (common in MMSegmentation models)
74
+ if hasattr(RELEM_MODEL, 'PALETTE'):
75
+ # This is a placeholder as the PALETTE logic is complex.
76
+ # We'll just return the raw NumPy array which Gradio might handle better:
77
+ return Image.fromarray(seg_mask_array.astype(np.uint8))
78
+ else:
79
+ # Fallback to the simple (but likely black/grayscale) conversion
80
+ return Image.fromarray(seg_mask_array.astype(np.uint8)).convert("L")
81
 
82
  except Exception as e:
83
+ # This will catch any error during inference and tell you what went wrong.
84
+ return f"Inference failed at runtime: {e}"
85
 
86
  # --- 5. GRADIO INTERFACE ---
87
  gr.Interface(