Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -44,34 +44,44 @@ RELEM_MODEL = load_relem_model()
|
|
| 44 |
|
| 45 |
|
| 46 |
# --- 4. Inference Function for Gradio ---
|
|
|
|
| 47 |
def segment_food(input_image: Image.Image):
|
| 48 |
"""Takes a PIL Image and returns a segmentation mask image."""
|
| 49 |
|
| 50 |
if RELEM_MODEL is None:
|
|
|
|
| 51 |
return "Error: Model failed to load. Check logs for details."
|
| 52 |
|
| 53 |
try:
|
| 54 |
-
#
|
| 55 |
-
# The input is usually a filepath, so we need to save and then load
|
| 56 |
-
|
| 57 |
-
# 1. Save input image temporarily
|
| 58 |
temp_path = "/tmp/input_img.png"
|
| 59 |
input_image.save(temp_path)
|
| 60 |
|
| 61 |
-
# 2
|
| 62 |
result = inference_segmentor(RELEM_MODEL, temp_path)
|
| 63 |
|
| 64 |
-
# 3
|
| 65 |
-
#
|
| 66 |
-
#
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
except Exception as e:
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
# --- 5. GRADIO INTERFACE ---
|
| 77 |
gr.Interface(
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
# --- 4. Inference Function for Gradio ---
|
| 47 |
+
# --- 4. Inference Function for Gradio (REVISED) ---
|
| 48 |
def segment_food(input_image: Image.Image):
|
| 49 |
"""Takes a PIL Image and returns a segmentation mask image."""
|
| 50 |
|
| 51 |
if RELEM_MODEL is None:
|
| 52 |
+
# If model failed to load, return the error message
|
| 53 |
return "Error: Model failed to load. Check logs for details."
|
| 54 |
|
| 55 |
try:
|
| 56 |
+
# Step 1: Save input image temporarily
|
|
|
|
|
|
|
|
|
|
| 57 |
temp_path = "/tmp/input_img.png"
|
| 58 |
input_image.save(temp_path)
|
| 59 |
|
| 60 |
+
# Step 2: Run Inference (Produces the class ID map)
|
| 61 |
result = inference_segmentor(RELEM_MODEL, temp_path)
|
| 62 |
|
| 63 |
+
# Step 3: Use the MMSegmentation utility to visualize the result
|
| 64 |
+
# NOTE: This requires the 'show_result_pyplot' utility and the model's palette.
|
| 65 |
+
# Since we don't have the utility, we will use a simplified NumPy approach that is less prone to memory errors:
|
| 66 |
+
|
| 67 |
+
seg_mask_array = result[0]
|
| 68 |
+
|
| 69 |
+
# --- SIMPLE NUMPY VISUALIZATION (Requires model to define a palette) ---
|
| 70 |
+
# MMSegmentation models usually contain a palette definition.
|
| 71 |
+
# If the model has a PALETTE property, use it. Otherwise, we convert to a colored PIL Image.
|
| 72 |
|
| 73 |
+
# Assuming RELEM_MODEL has a palette defined (common in MMSegmentation models)
|
| 74 |
+
if hasattr(RELEM_MODEL, 'PALETTE'):
|
| 75 |
+
# This is a placeholder as the PALETTE logic is complex.
|
| 76 |
+
# We'll just return the raw NumPy array which Gradio might handle better:
|
| 77 |
+
return Image.fromarray(seg_mask_array.astype(np.uint8))
|
| 78 |
+
else:
|
| 79 |
+
# Fallback to the simple (but likely black/grayscale) conversion
|
| 80 |
+
return Image.fromarray(seg_mask_array.astype(np.uint8)).convert("L")
|
| 81 |
|
| 82 |
except Exception as e:
|
| 83 |
+
# This will catch any error during inference and tell you what went wrong.
|
| 84 |
+
return f"Inference failed at runtime: {e}"
|
| 85 |
|
| 86 |
# --- 5. GRADIO INTERFACE ---
|
| 87 |
gr.Interface(
|