import gradio as gr
import torch
from PIL import Image
import numpy as np
import subprocess
import sys
import io
import matplotlib.pyplot as plt
import traceback

# --- CRITICAL PATCH: provide torch._six.container_abcs for older code ---
# Newer PyTorch releases no longer expose `container_abcs` in torch._six, but
# older third-party code still does `from torch._six import container_abcs`.
# Binding the name only in this module does not help that code, so when the
# import fails the stdlib alias is also attached to torch._six itself (if that
# module still exists in the installed PyTorch).
try:
    from torch._six import container_abcs
except ImportError:
    import collections.abc as container_abcs

    try:
        import torch._six as _torch_six
    except ImportError:
        _torch_six = None  # torch._six is gone entirely; nothing to patch
    if _torch_six is not None:
        _torch_six.container_abcs = container_abcs

# --- 0. FORCE INSTALL: Install pre-built mmcv-full for _ext modules ---
# Installs the mmcv-full 1.7.1 wheel from OpenMMLab's CPU / torch1.13 index,
# which ships the compiled _ext modules.
# NOTE(review): this is the CPU wheel index while DEVICE (below) may resolve to
# 'cuda:0' — confirm the model does not need mmcv CUDA kernels.
try:
    print("INFO: Attempting to install pre-built mmcv-full...")
    subprocess.check_call([
        sys.executable, '-m', 'pip', 'install', 'mmcv-full==1.7.1',
        '-f', 'https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html',
    ])
    print("INFO: Successfully installed pre-built mmcv-full.")
except subprocess.CalledProcessError as e:
    # Best effort: keep starting up so the app can report the problem; the
    # mmseg import / model load below will most likely fail.
    print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")

# --- 1. Load Custom Model Utilities (Must come after mmcv is installed) ---
try:
    from mmseg.apis import init_segmentor, inference_segmentor
except Exception as e:
    # Caught broadly: a broken mmcv install may surface as something other
    # than ImportError. Print the traceback so the cause is visible in logs.
    print(f"FATAL ERROR: Failed to import mmseg utilities: {e}")
    traceback.print_exc()
    # With these set to None the model loader returns None, and the app
    # reports "Model failed to load" instead of crashing at startup.
    init_segmentor = None
    inference_segmentor = None

# --- 2. CONFIGURATION ---
# Paths are relative to the current working directory.
WEIGHTS_PATH = "R50_ReLeM.pth"
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# --- 3. Model Loading Function ---
@torch.no_grad()
def load_relem_model():
    """Initialize the segmentation model and load the pre-trained weights.

    Returns:
        The model built by ``init_segmentor`` from CONFIG_FILE / WEIGHTS_PATH on
        DEVICE, switched to eval mode; or ``None`` if the mmseg imports failed
        or building/loading the model raised. Load errors are printed with a
        traceback instead of being raised, so the app can still start and
        report the problem per request.
    """
    if init_segmentor is None:
        return None  # Skip if imports failed
    try:
        model = init_segmentor(CONFIG_FILE, checkpoint=WEIGHTS_PATH, device=DEVICE)
        model.eval()
        print(f"ReLeM Model loaded successfully onto {DEVICE}!")
        return model
    except Exception as e:
        print(f"CRITICAL ERROR: Model failed to load weights or config: {e}")
        traceback.print_exc()
        return None


RELEM_MODEL = load_relem_model()


# --- 4. Inference Function for Gradio ---
def segment_food(input_image: Image.Image):
    """Run ReLeM segmentation on an uploaded image and return a colour mask.

    Args:
        input_image: The uploaded PIL image (``gr.Image(type="pil")``), or
            ``None`` when the user submits without uploading anything.

    Returns:
        A PIL image of the predicted class-ID map rendered with matplotlib's
        ``nipy_spectral`` colormap.

    Raises:
        gr.Error: If the model failed to load at startup, no image was given,
            or inference failed. Gradio shows the message in the UI; returning
            a plain string would not work because the output is a ``gr.Image``.
    """
    import os
    import tempfile
    from matplotlib.figure import Figure

    if RELEM_MODEL is None:
        raise gr.Error("Error: Model failed to load at startup. Check build logs for reason.")
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    try:
        # Step 1: mmseg's inference pipeline here is fed a file path, so save
        # the image into a per-request temporary directory. A fixed shared
        # path would let concurrent requests overwrite each other's input.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "input_img.png")
            input_image.save(temp_path)
            # Step 2: Run Inference (Produces the raw class ID map)
            result = inference_segmentor(RELEM_MODEL, temp_path)

        # Step 3: Post-process the result into a COLORFUL image
        seg_mask_array = result[0]

        # Use a standalone Figure rather than pyplot: pyplot keeps a global
        # "current figure" that concurrent Gradio worker threads would share.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest')
        ax.axis('off')

        # Save this figure (not pyplot's current one) to an in-memory buffer.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        print(f"RUNTIME CRASH: Inference failed with error: {e}")
        traceback.print_exc()
        raise gr.Error(f"Inference failed at runtime. Error: {e}. Try a smaller image.") from e


# --- 5. GRADIO INTERFACE ---
demo = gr.Interface(
    fn=segment_food,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
    title="ReLeM (FoodSeg103) Deployment Final Attempt",
    description="Custom deployment of the ReLeM PyTorch model. Check logs for deployment status.",
    allow_flagging="never",
)
demo.launch()