File size: 4,286 Bytes
ca8a9df
 
 
 
618d0c1
 
 
 
 
 
264c98b
 
 
 
 
 
618d0c1
264c98b
618d0c1
 
264c98b
618d0c1
 
 
 
 
 
 
 
264c98b
 
 
ca8a9df
264c98b
618d0c1
 
 
 
264c98b
 
 
ca8a9df
 
 
 
618d0c1
 
ca8a9df
 
 
 
 
264c98b
 
 
ca8a9df
 
618d0c1
 
 
ca8a9df
 
618d0c1
ca8a9df
 
618d0c1
 
ca8a9df
 
 
 
 
 
 
618d0c1
 
ca8a9df
264c98b
ca8a9df
 
618d0c1
ca8a9df
 
 
618d0c1
ca8a9df
618d0c1
deddc3a
97d3c71
 
618d0c1
264c98b
 
deddc3a
 
618d0c1
deddc3a
 
264c98b
deddc3a
ca8a9df
deddc3a
ca8a9df
 
deddc3a
 
618d0c1
ca8a9df
 
 
 
 
 
618d0c1
 
ca8a9df
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import io
import os
import subprocess
import sys
import tempfile
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# --- CRITICAL PATCH: Fix for 'container_abcs' not found in torch._six ---
# Older code does `from torch._six import container_abcs`, which PyTorch 1.13
# no longer provides. Binding the name in THIS module alone does not help those
# other modules, so when `torch._six` exists we also attach `collections.abc`
# to it. This runs before mmseg is imported further down.
try:
    from torch._six import container_abcs
except ImportError:
    import collections.abc as container_abcs

    try:
        import torch._six as _torch_six
    except ImportError:
        # torch._six is absent altogether (e.g. PyTorch 2.x); nothing to patch.
        _torch_six = None
    if _torch_six is not None:
        _torch_six.container_abcs = container_abcs

# --- 0. FORCE INSTALL: Install pre-built mmcv-full for _ext modules ---
# Pulls mmcv-full 1.7.1 from OpenMMLab's CPU / torch1.13 wheel index, whose
# pre-built wheel carries the compiled `_ext` ops. This is best-effort: a pip
# failure is only logged, and the mmseg import / model load below will then
# report the actual breakage.
print("INFO: Attempting to install pre-built mmcv-full...")
try:
    subprocess.check_call(
        [
            sys.executable, '-m', 'pip', 'install',
            'mmcv-full==1.7.1',
            '-f', 'https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html',
        ]
    )
except subprocess.CalledProcessError as install_error:
    # Deliberately not re-raised: startup continues without mmcv-full.
    print(f"FATAL ERROR: Failed to install pre-built mmcv-full via subprocess. {install_error}")
else:
    print("INFO: Successfully installed pre-built mmcv-full.")


# --- 1. Load Custom Model Utilities (Must come after mmcv is installed) ---
# Any import failure is logged and both entry points are set to None.
# load_relem_model() checks for that and skips model construction, which in
# turn makes the app report "Model failed to load".
try:
    from mmseg.apis import init_segmentor, inference_segmentor
except Exception as import_error:
    print(f"FATAL ERROR: Failed to import mmseg utilities: {import_error}")
    init_segmentor = inference_segmentor = None


# --- 2. CONFIGURATION ---
# Relative paths to the checkpoint and the mmseg config for the ReLeM model.
WEIGHTS_PATH = "R50_ReLeM.pth"
CONFIG_FILE = "configs/foodnet/SETR_Naive_768x768_80k_base_RM.py"

# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# --- 3. Model Loading Function ---
@torch.no_grad()
def load_relem_model():
    """Build the ReLeM segmentor from CONFIG_FILE and WEIGHTS_PATH on DEVICE.

    Returns:
        The model after ``.eval()`` has been called on it, or ``None`` when
        the mmseg imports failed at startup, or when building the model or
        loading the checkpoint raised. In the latter case the error and its
        traceback are printed.
    """
    # The mmseg import failed earlier, so there is nothing to build.
    if init_segmentor is None:
        return None

    try:
        segmentor = init_segmentor(
            CONFIG_FILE,
            checkpoint=WEIGHTS_PATH,
            device=DEVICE,
        )
        segmentor.eval()  # switch the network to inference mode
        print(f"ReLeM Model loaded successfully onto {DEVICE}!")
    except Exception as load_error:
        print(f"CRITICAL ERROR: Model failed to load weights or config: {load_error}")
        traceback.print_exc()
        segmentor = None
    return segmentor

# Load the model once, at import time. A None value (mmseg imports failed or
# loading raised) is what segment_food checks for before running inference.
RELEM_MODEL = load_relem_model()


# --- 4. Inference Function for Gradio ---
def _render_mask(seg_mask_array) -> Image.Image:
    """Render a 2-D class-ID map with the nipy_spectral colormap as a PIL image."""
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # 'nearest' keeps class boundaries crisp instead of blending IDs.
        ax.imshow(seg_mask_array, cmap='nipy_spectral', interpolation='nearest')
        ax.axis('off')
        buf = io.BytesIO()
        # Save through `fig` rather than pyplot's implicit "current figure".
        # That is process-global state and could refer to another request's
        # figure if handlers ever run concurrently.
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)  # always release the figure, even if rendering fails
    buf.seek(0)
    return Image.open(buf)


def segment_food(input_image: Image.Image):
    """Segment *input_image* and return a colour-coded class mask.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input. It may
            be ``None`` if the user submits without uploading an image.

    Returns:
        A PIL image of the predicted class-ID map, rendered with the
        ``nipy_spectral`` colormap.

    Raises:
        gr.Error: if the model failed to load at startup, no image was given,
            or inference/rendering fails. Gradio shows the message to the
            user. Returning a plain string, as before, would be interpreted
            by the ``gr.Image`` output as a file path instead.
    """
    if RELEM_MODEL is None:
        raise gr.Error("Error: Model failed to load at startup. Check build logs for reason.")
    if input_image is None:
        raise gr.Error("Error: No image provided. Please upload a food image.")

    # mmseg's inference is fed a file path. A unique temp file per call stops
    # concurrent requests from overwriting each other's input.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        input_image.save(temp_path)

        # Run inference; result[0] is taken as the class-ID map of the image.
        result = inference_segmentor(RELEM_MODEL, temp_path)
        return _render_mask(result[0])
    except Exception as e:
        print(f"RUNTIME CRASH: Inference failed with error: {e}")
        traceback.print_exc()
        raise gr.Error(f"Inference failed at runtime. Error: {e}. Try a smaller image.") from e
    finally:
        # Best-effort cleanup; a failure to delete must not mask the result.
        try:
            os.remove(temp_path)
        except OSError:
            pass

# --- 5. GRADIO INTERFACE ---
# One image in, one rendered mask out. The server is launched as soon as this
# script runs.
upload_box = gr.Image(type="pil", label="Upload Food Image")
mask_view = gr.Image(type="pil", label="ReLeM Segmentation Mask")

demo = gr.Interface(
    fn=segment_food,
    inputs=upload_box,
    outputs=mask_view,
    title="ReLeM (FoodSeg103) Deployment Final Attempt",
    description="Custom deployment of the ReLeM PyTorch model. Check logs for deployment status.",
    allow_flagging="never",
)
demo.launch()