|
|
import os
import tempfile
import traceback

import gradio as gr
import numpy as np
import torch
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from mmseg.apis import init_segmentor, inference_segmentor |
|
|
from mmseg.datasets import build_dataloader, build_dataset |
|
|
|
|
|
except ImportError: |
|
|
print("MMSegmentation utilities not found. Ensure files were copied correctly.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint passed to init_segmentor() by load_relem_model(). This is a
# relative path, so it resolves against the directory the app is started from.
WEIGHTS_PATH = "R50_ReLeM.pth"

# MMSegmentation config used to build the network, also a relative path: the
# repo's configs/ tree must be present next to the app (see the UI description).
# NOTE(review): the checkpoint name suggests a ResNet-50 model while this config
# name suggests SETR — confirm the two are a matching pair, otherwise weight
# loading may report missing/unexpected keys.
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
|
|
|
|
|
|
|
|
def load_relem_model(config_file=None, checkpoint=None, device=None):
    """Build the ReLeM segmentor and load its pre-trained weights.

    All arguments are optional; omitting them reproduces the app's default
    setup.

    Args:
        config_file: MMSegmentation config path. Defaults to ``CONFIG_FILE``.
        checkpoint: Checkpoint path. Defaults to ``WEIGHTS_PATH``.
        device: Torch device string. Defaults to ``'cuda:0'`` when CUDA is
            available, otherwise ``'cpu'``.

    Returns:
        The model switched to eval mode, or ``None`` if loading failed. Loading
        is deliberately best-effort so the UI can still start and report the
        problem per request.
    """
    if config_file is None:
        config_file = CONFIG_FILE
    if checkpoint is None:
        checkpoint = WEIGHTS_PATH
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    try:
        model = init_segmentor(config_file, checkpoint=checkpoint, device=device)
    except Exception as e:
        # Print the full traceback, not just str(e): config/checkpoint problems
        # (and a failed mmseg import, which surfaces here as a NameError) are
        # otherwise very hard to diagnose from a one-line message.
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return None
    else:
        model.eval()
        print("ReLeM Model loaded successfully!")
        return model
|
|
|
|
|
|
|
|
# Load the model once at import time so every request reuses the same
# instance. This is None when loading failed, which segment_food checks for.
RELEM_MODEL = load_relem_model()
|
|
|
|
|
|
|
|
|
|
|
def segment_food(input_image: Image.Image) -> Image.Image:
    """Run ReLeM segmentation on an uploaded image.

    Args:
        input_image: Image provided by the Gradio UI. It is ``None`` when the
            user submits without uploading anything.

    Returns:
        A single-channel (mode ``"L"``) PIL image whose pixel values are the
        predicted class indices cast to uint8. It is a raw label map, not a
        color-coded visualization, so it usually looks very dark.

    Raises:
        gr.Error: If the model failed to load, no image was given, or
            inference failed. Gradio shows the message in the UI. Returning
            a string instead would make the ``gr.Image`` output fail with an
            unrelated error.
    """
    if RELEM_MODEL is None:
        raise gr.Error("Model failed to load. Check logs for details.")
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    try:
        # inference_segmentor is given a file path here. Use a private temp dir
        # per request so concurrent users never overwrite each other's input,
        # and the file is removed afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "input_img.png")
            input_image.save(temp_path)
            result = inference_segmentor(RELEM_MODEL, temp_path)

        # presumably a list with one label map per input (mmseg 0.x API);
        # a single path was passed, so take the first entry.
        seg_mask_array = result[0]
        return Image.fromarray(seg_mask_array.astype(np.uint8)).convert("L")
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Inference failed: {e}") from e
|
|
|
|
|
|
|
|
# Single-image-in, single-image-out UI around segment_food.
image_input = gr.Image(type="pil", label="Upload Food Image")
mask_output = gr.Image(type="pil", label="ReLeM Segmentation Mask")

demo = gr.Interface(
    fn=segment_food,
    inputs=image_input,
    outputs=mask_output,
    title="ReLeM (FoodSeg103) Segmentation Demo",
    description="Custom deployment of the ReLeM PyTorch model. **NOTE:** Model loading requires the full code/config structure from the GitHub repo.",
    allow_flagging="never",
)

# launch() blocks while serving the app.
# NOTE(review): this runs at import time. Wrap it in
# `if __name__ == "__main__":` if this module is ever imported rather than
# executed as a script.
demo.launch()