|
|
import io
import os
import subprocess
import sys
import tempfile
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
# Compatibility shim: `torch._six` was removed in newer PyTorch releases.
# Old mmcv builds expect `container_abcs`, so expose the standard-library
# ABCs under the same alias when the legacy torch module is gone.
try:
    from torch._six import container_abcs  # PyTorch < 1.9
except ImportError:
    from collections import abc as container_abcs  # modern PyTorch
|
|
|
|
|
|
|
|
# Best-effort runtime install of a pre-built CPU mmcv-full wheel (avoids a
# long source compile on the deployment host). Failure is logged but not
# fatal here: the mmseg import guard below decides whether the app can run.
print("INFO: Attempting to install pre-built mmcv-full...")
_mmcv_install_cmd = [
    sys.executable, '-m', 'pip', 'install',
    'mmcv-full==1.7.1',
    '-f', 'https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html',
]
try:
    subprocess.check_call(_mmcv_install_cmd)
except subprocess.CalledProcessError as e:
    print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {e}")
else:
    print("INFO: Successfully installed pre-built mmcv-full.")
|
|
|
|
|
|
|
|
|
|
|
# mmseg is only importable when the mmcv install above succeeded. Start
# with None sentinels so the Gradio UI can still launch and report the
# failure instead of crashing at import time; a successful import simply
# overwrites them with the real entry points.
init_segmentor = None
inference_segmentor = None
try:
    from mmseg.apis import init_segmentor, inference_segmentor
except Exception as e:
    print(f"FATAL ERROR: Failed to import mmseg utilities: {e}")
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint file expected next to the app at startup — presumably the
# ReLeM-pretrained ResNet-50 weights; confirm against the deployment assets.
WEIGHTS_PATH = "R50_ReLeM.pth"
# mmsegmentation config describing the SETR-Naive 768x768 model variant.
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"
# Prefer the first CUDA device when available, otherwise run on CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
@torch.no_grad()
def load_relem_model():
    """Build the segmentor from CONFIG_FILE and load WEIGHTS_PATH onto DEVICE.

    Returns the model in eval mode, or ``None`` when mmseg is unavailable
    or initialization fails (the error is logged so the request handler
    can surface a friendly message instead).
    """
    # mmseg never imported — nothing to load.
    if init_segmentor is None:
        return None

    try:
        segmentor = init_segmentor(CONFIG_FILE, checkpoint=WEIGHTS_PATH, device=DEVICE)
        segmentor.eval()
        print(f"ReLeM Model loaded successfully onto {DEVICE}!")
        return segmentor
    except Exception as e:
        # Catch-all boundary: any config/checkpoint problem is logged and
        # converted into a None model rather than killing the app.
        print(f"CRITICAL ERROR: Model failed to load weights or config: {e}")
        traceback.print_exc()
        return None
|
|
|
|
|
# Load once at import time so every Gradio request reuses the same model
# instance; None here means startup failed and requests will return an error.
RELEM_MODEL = load_relem_model()
|
|
|
|
|
|
|
|
|
|
|
def segment_food(input_image: Image.Image):
    """Run ReLeM semantic segmentation on an uploaded food photo.

    Args:
        input_image: PIL image from the Gradio input component. May be
            ``None`` when the user submits without selecting a file.

    Returns:
        A PIL image rendering the per-pixel class-index mask on success,
        or a human-readable error string when the model is unavailable
        or inference fails (matching the original error contract).
    """
    if RELEM_MODEL is None:
        return "Error: Model failed to load at startup. Check build logs for reason."

    # Gradio invokes the callback with None when no image was provided;
    # fail with a clear message instead of an AttributeError.
    if input_image is None:
        return "Error: No image provided. Please upload a food image."

    try:
        # inference_segmentor expects a file path. Write to a unique
        # per-request temporary directory so concurrent requests cannot
        # clobber each other (the previous fixed /tmp/input_img.png was
        # a race under simultaneous users); the directory is cleaned up
        # automatically.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "input_img.png")
            input_image.save(temp_path)
            result = inference_segmentor(RELEM_MODEL, temp_path)

        # result[0] is the (H, W) array of predicted class indices.
        seg_mask_array = result[0]

        # Render the index mask with a high-contrast colormap; nearest
        # interpolation keeps class boundaries crisp.
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest')
        ax.axis('off')

        # Round-trip through an in-memory PNG to hand Gradio a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        plt.close(fig)  # release the figure; matplotlib keeps figures alive otherwise
        buf.seek(0)

        return Image.open(buf)

    except Exception as e:
        # Boundary handler: log the traceback and return a readable error
        # rather than crashing the Gradio worker.
        print(f"RUNTIME CRASH: Inference failed with error: {e}")
        traceback.print_exc()
        return f"Inference failed at runtime. Error: {e}. Try a smaller image."
|
|
|
|
|
|
|
|
# Wire the segmentation handler into a simple image-in/image-out UI and
# start the server (blocking call; this is the script's entry point).
# NOTE(review): on failure paths segment_food returns a *string* while the
# output component is gr.Image — confirm the pinned Gradio version surfaces
# that as an error rather than crashing the render.
# NOTE(review): `allow_flagging` is deprecated/renamed in Gradio >= 4 —
# confirm the installed version still accepts this keyword.
gr.Interface(
    fn=segment_food,
    inputs=gr.Image(type="pil", label="Upload Food Image"),
    outputs=gr.Image(type="pil", label="ReLeM Segmentation Mask"),
    title="ReLeM (FoodSeg103) Deployment Final Attempt",
    description="Custom deployment of the ReLeM PyTorch model. Check logs for deployment status.",
    allow_flagging="never"
).launch()