# Source: Hugging Face Space file "app.py" by user Johnnyyyyy56 (commit 264c98b, verified).
import importlib.util
import io
import subprocess
import sys
import tempfile
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
# --- CRITICAL PATCH: Fix for 'container_abcs' not found in torch._six ---
# Newer PyTorch releases removed the torch._six shim, but older mmcv/mmseg
# code still does `from torch._six import container_abcs`.  Aliasing the name
# locally is not enough: third-party imports resolve through sys.modules, so
# we also register a stub "torch._six" module there.
try:
    from torch._six import container_abcs
except ImportError:
    import collections.abc as container_abcs
    import types

    _six_stub = types.ModuleType("torch._six")
    _six_stub.container_abcs = container_abcs
    sys.modules["torch._six"] = _six_stub
    # Also attach as an attribute so `torch._six.container_abcs` works.
    torch._six = _six_stub
# --- 0. FORCE INSTALL: Install pre-built mmcv-full for _ext modules ---
def _ensure_mmcv_full() -> None:
    """Install the pre-built mmcv-full wheel (with compiled _ext ops) if absent.

    The wheel index is pinned to the CPU/PyTorch 1.13 build so the compiled
    extension modules match the torch version this Space targets.  A failed
    install is logged but not fatal here; the model load below will surface
    the problem as a user-visible error instead.
    """
    # Skip the slow pip call when mmcv is already importable (e.g. warm restart).
    if importlib.util.find_spec("mmcv") is not None:
        print("INFO: mmcv already present; skipping install.")
        return
    try:
        print("INFO: Attempting to install pre-built mmcv-full...")
        # Pre-built wheel for PyTorch 1.13 includes the necessary compiled _ext modules.
        subprocess.check_call([
            sys.executable, '-m', 'pip', 'install',
            'mmcv-full==1.7.1',
            '-f', 'https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html'
        ])
        print("INFO: Successfully installed pre-built mmcv-full.")
    except subprocess.CalledProcessError as e:
        print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")
        # Continue execution; the model will likely fail to load later and report it.


_ensure_mmcv_full()
# --- 1. Load Custom Model Utilities (Must come after mmcv is installed) ---
# Pre-set None sentinels; a successful import overwrites them.  On failure the
# app degrades to the "Model failed to load" message instead of crashing.
init_segmentor = None
inference_segmentor = None
try:
    from mmseg.apis import init_segmentor, inference_segmentor
except Exception as import_err:
    print(f"FATAL ERROR: Failed to import mmseg utilities: {import_err}")
# --- 2. CONFIGURATION ---
# Checkpoint file with the pre-trained ReLeM weights (loaded by init_segmentor).
WEIGHTS_PATH = "R50_ReLeM.pth"
# mmseg config describing the model — presumably a SETR-Naive 768x768 setup
# for the FoodSeg project; verify against the configs/ directory.
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
# Prefer the first CUDA device when available; otherwise fall back to CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# --- 3. Model Loading Function ---
@torch.no_grad()
def load_relem_model():
    """Build the segmentor from CONFIG_FILE and load the WEIGHTS_PATH checkpoint.

    Returns the model in eval mode on DEVICE, or None when the mmseg imports
    failed at startup or initialization raised.
    """
    if init_segmentor is None:
        # mmseg never imported, so there is nothing to initialize.
        return None
    try:
        segmentor = init_segmentor(CONFIG_FILE, checkpoint=WEIGHTS_PATH, device=DEVICE)
        segmentor.eval()
        print(f"ReLeM Model loaded successfully onto {DEVICE}!")
        return segmentor
    except Exception as load_err:
        print(f"CRITICAL ERROR: Model failed to load weights or config: {load_err}")
        traceback.print_exc()
        return None
# Load the model once at import time; a None result is handled in segment_food.
RELEM_MODEL = load_relem_model()
# --- 4. Inference Function for Gradio ---
def segment_food(input_image: Image.Image):
    """Run the ReLeM segmentor on a PIL image and return a colorful mask.

    Args:
        input_image: The uploaded image.

    Returns:
        A PIL.Image.Image with the per-pixel class-ID map rendered through the
        'nipy_spectral' colormap, or an error string when the model is
        unavailable or inference fails.
    """
    if RELEM_MODEL is None:
        return "Error: Model failed to load at startup. Check build logs for reason."
    try:
        # mmseg's inference pipeline reads from a file path, so write the
        # upload to a unique temp file.  (A fixed /tmp path would race between
        # concurrent Gradio requests.)  The directory cleans itself up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = f"{tmp_dir}/input_img.png"
            input_image.save(temp_path)
            # Run inference; result[0] is the raw class-ID map for this image.
            result = inference_segmentor(RELEM_MODEL, temp_path)
        seg_mask_array = result[0]
        # --- MATPLOTLIB VISUALIZATION (Robust Color Mask) ---
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest')
        ax.axis('off')
        # Render the figure to an in-memory PNG buffer.
        buf = io.BytesIO()
        try:
            plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, even if savefig raises.
            plt.close(fig)
        buf.seek(0)
        mask_image = Image.open(buf)
        # Force a full decode so the returned image no longer depends on `buf`.
        mask_image.load()
        return mask_image
    except Exception as e:
        print(f"RUNTIME CRASH: Inference failed with error: {e}")
        traceback.print_exc()
        return f"Inference failed at runtime. Error: {e}. Try a smaller image."
# --- 5. GRADIO INTERFACE ---
# Bind the app to the conventional `demo` name (what the HF Spaces Gradio SDK
# looks for) and guard launch() so importing this module does not start the
# server.  HF Spaces executes app.py as __main__, so deployment is unchanged.
demo = gr.Interface(
    fn=segment_food,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
    title="ReLeM (FoodSeg103) Deployment Final Attempt",
    description="Custom deployment of the ReLeM PyTorch model. Check logs for deployment status.",
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()