File size: 4,286 Bytes
ca8a9df 618d0c1 264c98b 618d0c1 264c98b 618d0c1 264c98b 618d0c1 264c98b ca8a9df 264c98b 618d0c1 264c98b ca8a9df 618d0c1 ca8a9df 264c98b ca8a9df 618d0c1 ca8a9df 618d0c1 ca8a9df 618d0c1 ca8a9df 618d0c1 ca8a9df 264c98b ca8a9df 618d0c1 ca8a9df 618d0c1 ca8a9df 618d0c1 deddc3a 97d3c71 618d0c1 264c98b deddc3a 618d0c1 deddc3a 264c98b deddc3a ca8a9df deddc3a ca8a9df deddc3a 618d0c1 ca8a9df 618d0c1 ca8a9df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import io
import os
import subprocess
import sys
import tempfile
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
# --- CRITICAL PATCH: Fix for 'container_abcs' not found in torch._six ---
# Older third-party code (mmcv/mmseg) does `from torch._six import
# container_abcs`; that name was removed from torch._six in newer PyTorch.
try:
    from torch._six import container_abcs
except ImportError:
    # The ABCs live in the standard library; alias them locally...
    import collections.abc as container_abcs
    import types
    # ...but a local alias alone does NOT fix downstream imports: mmcv runs
    # its own `from torch._six import container_abcs`. Patch (or create and
    # register) the torch._six module so that import resolves too.
    _six_mod = getattr(torch, "_six", None)
    if _six_mod is None:
        _six_mod = types.ModuleType("torch._six")
        torch._six = _six_mod
        sys.modules["torch._six"] = _six_mod
    if not hasattr(_six_mod, "container_abcs"):
        _six_mod.container_abcs = container_abcs
# --- 0. FORCE INSTALL: Install pre-built mmcv-full for _ext modules ---
# The OpenMMLab wheel index below serves mmcv-full builds compiled against
# PyTorch 1.13 (CPU), so the binary _ext ops arrive pre-compiled rather than
# being built from source on the Space.
print("INFO: Attempting to install pre-built mmcv-full...")
_mmcv_install_cmd = [
    sys.executable, '-m', 'pip', 'install',
    'mmcv-full==1.7.1',
    '-f', 'https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html',
]
try:
    subprocess.check_call(_mmcv_install_cmd)
except subprocess.CalledProcessError as e:
    print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")
    # Deliberately keep going: the app can still start, but the model will
    # likely fail to load later without the compiled ops.
else:
    print("INFO: Successfully installed pre-built mmcv-full.")
# --- 1. Load Custom Model Utilities (Must come after mmcv is installed) ---
# Default both entry points to None so a failed import degrades gracefully:
# load_relem_model() checks for None and the UI reports a startup error.
init_segmentor = None
inference_segmentor = None
try:
    from mmseg.apis import init_segmentor, inference_segmentor
except Exception as e:
    print(f"FATAL ERROR: Failed to import mmseg utilities: {e}")
# --- 2. CONFIGURATION ---
WEIGHTS_PATH = "R50_ReLeM.pth"  # pre-trained checkpoint, expected in the working directory
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"  # mmseg model/config definition
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # prefer GPU when available
# --- 3. Model Loading Function ---
@torch.no_grad()
def load_relem_model():
    """Construct the segmentor from CONFIG_FILE/WEIGHTS_PATH on DEVICE.

    Returns the model in eval mode, or None when the mmseg imports failed
    at startup or when initialization itself raised.
    """
    if init_segmentor is None:
        # mmseg never imported — nothing to build.
        return None
    try:
        segmentor = init_segmentor(CONFIG_FILE, checkpoint=WEIGHTS_PATH, device=DEVICE)
        segmentor.eval()
        print(f"ReLeM Model loaded successfully onto {DEVICE}!")
        return segmentor
    except Exception as e:
        print(f"CRITICAL ERROR: Model failed to load weights or config: {e}")
        traceback.print_exc()
        return None
RELEM_MODEL = load_relem_model()  # loaded once at startup; None signals failure to the UI
# --- 4. Inference Function for Gradio ---
def segment_food(input_image: Image.Image):
    """Run ReLeM semantic segmentation on an uploaded PIL image.

    Returns a colorized segmentation mask (PIL.Image) on success, or an
    error-message string on failure (Gradio surfaces it in place of the
    image output).
    """
    if RELEM_MODEL is None:
        return "Error: Model failed to load at startup. Check build logs for reason."
    temp_path = None
    try:
        # mmseg's inference pipeline reads from a file path, so persist the
        # upload to a *unique* temp file — a fixed /tmp path would be
        # clobbered when two requests arrive concurrently.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        input_image.save(temp_path)
        # Raw per-pixel class-ID map; result is a list with one entry per image.
        result = inference_segmentor(RELEM_MODEL, temp_path)
        seg_mask_array = result[0]
        # --- MATPLOTLIB VISUALIZATION (Robust Color Mask) ---
        # 'nipy_spectral' gives each class a distinct color; 'nearest'
        # prevents class IDs from being blended by interpolation.
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest')
            ax.axis('off')
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)  # release the figure even if rendering fails
        buf.seek(0)
        mask_image = Image.open(buf)
        mask_image.load()  # read fully now so the BytesIO can be collected
        return mask_image
    except Exception as e:
        print(f"RUNTIME CRASH: Inference failed with error: {e}")
        traceback.print_exc()
        return f"Inference failed at runtime. Error: {e}. Try a smaller image."
    finally:
        # Best-effort cleanup of the scratch file.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
# --- 5. GRADIO INTERFACE ---
# Build the web UI and start serving; launch() blocks until shutdown.
demo = gr.Interface(
    fn=segment_food,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
    title="ReLeM (FoodSeg103) Deployment Final Attempt",
    description="Custom deployment of the ReLeM PyTorch model. Check logs for deployment status.",
    allow_flagging="never",
)
demo.launch()