NG
Initial clean push of functional ReLeM code and weights
ca8a9df
raw
history blame
3.26 kB
import gradio as gr
import torch
from PIL import Image
import numpy as np
# --- 1. Load Custom Model Utilities ---
# NOTE: These imports MUST match the files you copied from the GitHub repo.
# Example imports - adjust these if the model files are deeper in subfolders!
try:
from mmseg.apis import init_segmentor, inference_segmentor # Core MMSeg functions
from mmseg.datasets import build_dataloader, build_dataset # Utilities
# You might also need to copy config files, e.g., to 'configs/relem/'
except ImportError:
print("MMSegmentation utilities not found. Ensure files were copied correctly.")
# --- 2. CONFIGURATION ---
# Both paths are relative, so they are resolved against the process's working
# directory when load_relem_model() hands them to init_segmentor.

# Checkpoint file passed as `checkpoint=` to init_segmentor.
WEIGHTS_PATH = "R50_ReLeM.pth"
# MMSegmentation config passed as the first argument to init_segmentor.
# NOTE(review): the checkpoint name says "R50" while this config name says
# "SETR_Naive" -- confirm that this config really describes the architecture
# the checkpoint was trained with, and replace it with the matching file from
# the repo if not.
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
# --- 3. Model Loading Function ---
@torch.no_grad()
def load_relem_model():
    """Build the segmentor described by CONFIG_FILE and load WEIGHTS_PATH into it.

    The model is placed on ``cuda:0`` when CUDA is available and on the CPU
    otherwise, then switched to eval mode.

    Returns:
        The initialised model, or ``None`` if anything went wrong while
        building it (the error is printed so it shows up in the logs).
    """
    try:
        # Pick the device up front so the init call below stays readable.
        if torch.cuda.is_available():
            target_device = 'cuda:0'
        else:
            target_device = 'cpu'

        segmentor = init_segmentor(
            CONFIG_FILE,
            checkpoint=WEIGHTS_PATH,
            device=target_device,
        )
        segmentor.eval()
        print("ReLeM Model loaded successfully!")
        return segmentor
    except Exception as e:
        # Startup boundary: report the failure and signal it with None so the
        # caller can decide what to do instead of crashing at import time.
        print(f"Error loading model: {e}")
        return None
# Load the model a single time at import so every request reuses the same
# instance. load_relem_model() returns None on failure, so RELEM_MODEL may be
# None here; segment_food() checks for that before running inference.
RELEM_MODEL = load_relem_model()
# --- 4. Inference Function for Gradio ---
def segment_food(input_image: Image.Image):
    """Run the loaded segmentor on an uploaded image and return its mask.

    Args:
        input_image: PIL image coming from the Gradio image input. It is
            ``None`` when the user submits without uploading anything.

    Returns:
        A single-channel ("L") PIL image built from the first element of the
        inference result cast to ``uint8``. No colour palette is applied, so
        the pixel values are the raw values of that array.

    Raises:
        gr.Error: If the model failed to load at startup, if no image was
            provided, or if saving/inference/post-processing fails. The
            output component is an image, so returning an error string would
            not be displayable; raising ``gr.Error`` surfaces the message in
            the UI instead.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import tempfile

    if RELEM_MODEL is None:
        raise gr.Error("Error: Model failed to load. Check logs for details.")
    if input_image is None:
        raise gr.Error("No image provided. Please upload an image first.")

    try:
        # inference_segmentor is given a file path, so the upload is written
        # to disk first. A per-request temporary directory avoids concurrent
        # requests overwriting each other's file and guarantees cleanup.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "input_img.png")
            # Normalise to RGB so modes PNG cannot store (e.g. CMYK) do not
            # make the save fail.
            input_image.convert("RGB").save(temp_path)
            result = inference_segmentor(RELEM_MODEL, temp_path)

        # The first element of the result is used as the segmentation map.
        seg_mask_array = result[0]
        color_mask = Image.fromarray(seg_mask_array.astype(np.uint8)).convert("L")
    except Exception as e:
        # UI boundary: turn any failure into a message Gradio can display,
        # keeping the original exception chained for the server logs.
        raise gr.Error(f"Inference failed: {e}") from e

    # NOTE: Full color mapping requires the class labels/palette, which must
    # also be copied from the repo.
    return color_mask
# --- 5. GRADIO INTERFACE ---
# Build the input/output components first, then wire them to segment_food and
# start the app.
image_input = gr.Image(type="pil", label="Upload Food Image")
mask_output = gr.Image(type="pil", label="ReLeM Segmentation Mask")

demo = gr.Interface(
    fn=segment_food,
    inputs=image_input,
    outputs=mask_output,
    title="ReLeM (FoodSeg103) Segmentation Demo",
    description="Custom deployment of the ReLeM PyTorch model. **NOTE:** Model loading requires the full code/config structure from the GitHub repo.",
    allow_flagging="never",
)
demo.launch()