"""
SAM3 Static Image Segmentation - Correct Implementation

Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation.
"""
import base64
import io
import asyncio
import uuid
import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModel
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load SAM3 model for STATIC IMAGES
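# "./model" is assumed to be a local directory holding the SAM3 checkpoint
# (weights, config, and processor files); trust_remote_code executes the
# custom modeling code shipped with that checkpoint.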
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "./model",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)

model.eval()
if torch.cuda.is_available():
    model.cuda()
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")

# Simple concurrency control
class VRAMManager:
    def __init__(self):
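        # At most two inferences run concurrently to bound peak VRAM usage.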
        self.semaphore = asyncio.Semaphore(2)
        self.processing_count = 0

    def get_vram_status(self):
        if not torch.cuda.is_available():
            return {}
        return {
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated_gb": torch.cuda.memory_allocated() / 1e9,
            "free_gb": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1e9,
            "processing_now": self.processing_count
        }

    async def acquire(self, rid):
        await self.semaphore.acquire()
        self.processing_count += 1

    def release(self, rid):
        self.processing_count -= 1
        self.semaphore.release()
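        # Return PyTorch's unused cached allocator blocks to the GPU between requests.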
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

vram_manager = VRAMManager()
app = FastAPI(title="SAM3 Static Image API")

class Request(BaseModel):
    inputs: str            # base64-encoded image bytes
    parameters: dict = {}  # expects {"classes": [...]} with text prompts
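
# Example request body (sketch; values are placeholders):
#   {"inputs": "<base64-encoded JPEG/PNG bytes>",
#    "parameters": {"classes": ["crack", "pothole"]}}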


def run_inference(image_b64: str, classes: list, request_id: str):
    """
    Sam3Model inference for static images with text prompts.

    Uses official SAM3 processor post-processing for correct mask generation.
    """
    try:
        # Decode image
        image_bytes = base64.b64decode(image_b64)
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if not classes:
            raise ValueError("parameters.classes must be a non-empty list of text prompts")
        logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")

        # Process with Sam3Processor
        # Sam3Model expects: batch of images matching text prompts
        # For multiple objects in ONE image, repeat the image for each class
        images_batch = [pil_image] * len(classes)
        inputs = processor(
            images=images_batch,  # Repeat image for each text prompt
            text=classes,  # List of text prompts
            return_tensors="pt"
        )

        # Store original sizes for post-processing
        # Format: [[height, width]] for EACH image in batch
        # Since we repeat the image for each class, repeat the size too
        original_size = [pil_image.size[1], pil_image.size[0]]  # [height, width]
        original_sizes = torch.tensor([original_size] * len(classes))
        inputs["original_sizes"] = original_sizes

        logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
        logger.info(f"[{request_id}] Original size: {pil_image.size} (W x H)")

        # Move to GPU and match model dtype
        if torch.cuda.is_available():
            model_dtype = next(model.parameters()).dtype
            moved = {}
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    # Cast only floating-point tensors; token/id tensors stay integer
                    v = v.cuda().to(model_dtype) if v.dtype.is_floating_point else v.cuda()
                moved[k] = v
            inputs = moved
            logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")

        # Sam3Model Inference
        with torch.no_grad():
            outputs = model(**inputs)
            logger.info(f"[{request_id}] Forward pass successful!")

        logger.info(f"[{request_id}] Output type: {type(outputs)}")

        # Use processor's official post-processing method
        # This handles:
        # - Logit to probability conversion (sigmoid)
        # - Proper thresholding (default 0.5)
        # - Resizing to original image dimensions
        # - Score extraction
        logger.info(f"[{request_id}] Using processor.post_process_instance_segmentation()")

        try:
            processed = processor.post_process_instance_segmentation(
                outputs,
                threshold=0.3,          # Score threshold for detections (lowered to detect road cracks)
                mask_threshold=0.5,     # Probability threshold for mask pixels
                target_sizes=original_sizes.tolist()
            )
            # Returns a LIST of results, one per image in batch (one per class in our case)

            logger.info(f"[{request_id}] Post-processing successful!")
            logger.info(f"[{request_id}] Number of batched results: {len(processed)}")

        except Exception as proc_error:
            logger.error(f"[{request_id}] Post-processing failed: {proc_error}")
            logger.info(f"[{request_id}] Falling back to manual processing")

            # Fallback to manual processing with sigmoid fix
            results = []

            # Extract masks from outputs
            if hasattr(outputs, 'pred_masks'):
                pred_masks = outputs.pred_masks
            elif hasattr(outputs, 'masks'):
                pred_masks = outputs.masks
            elif isinstance(outputs, dict) and 'pred_masks' in outputs:
                pred_masks = outputs['pred_masks']
            else:
                raise ValueError("Cannot find masks in model output")

            logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")

            for i, cls in enumerate(classes):
                # One image was batched per class, so dim 0 of pred_masks
                # indexes the class; take that image's first mask query.
                if i < pred_masks.shape[0]:
                    mask_tensor = pred_masks[i, 0]

                    # Resize to original size
                    if mask_tensor.shape[-2:] != pil_image.size[::-1]:
                        mask_tensor = torch.nn.functional.interpolate(
                            mask_tensor.unsqueeze(0).unsqueeze(0),
                            size=pil_image.size[::-1],
                            mode='bilinear',
                            align_corners=False
                        ).squeeze()

                    # CRITICAL FIX: Convert logits to probabilities THEN threshold
                    probs = torch.sigmoid(mask_tensor)
                    binary_mask = (probs > 0.5).float().cpu().numpy().astype("uint8") * 255
                else:
                    binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8")

                # Convert to PNG
                pil_mask = Image.fromarray(binary_mask, mode="L")
                buf = io.BytesIO()
                pil_mask.save(buf, format="PNG")
                mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                # Extract score (assumes pred_logits follows the same
                # one-image-per-class batch layout as pred_masks)
                score = 1.0
                if hasattr(outputs, 'pred_logits') and i < outputs.pred_logits.shape[0]:
                    # Convert logits to probability
                    score = float(torch.sigmoid(outputs.pred_logits[i, 0]).cpu())

                results.append({
                    "label": cls,
                    "mask": mask_b64,
                    "score": score
                })

            logger.info(f"[{request_id}] Completed (fallback): {len(results)} masks generated")
            return results

        # Extract results from processor output
        # CRITICAL: processor returns one result dict per class (batched)
        # Each result dict contains MULTIPLE instances of that class
        results = []

        total_instances = 0
        for i, cls in enumerate(classes):
            class_result = processed[i]  # Results for this specific class

            num_instances = len(class_result['masks']) if 'masks' in class_result else 0
            total_instances += num_instances

            if num_instances > 0:
                logger.info(f"[{request_id}] {cls}: {num_instances} instance(s) detected")

                # Loop through ALL instances of this class
                for j in range(num_instances):
                    # Get mask (already binary, resized to original size)
                    mask_np = class_result['masks'][j].cpu().numpy().astype("uint8") * 255

                    # Convert to PNG
                    pil_mask = Image.fromarray(mask_np, mode="L")
                    buf = io.BytesIO()
                    pil_mask.save(buf, format="PNG")
                    mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                    # Get score (already converted to probability by processor)
                    score = float(class_result['scores'][j]) if 'scores' in class_result else 1.0

                    # Calculate coverage for logging
                    coverage = (mask_np > 0).sum() / mask_np.size * 100

                    results.append({
                        "label": cls,
                        "mask": mask_b64,
                        "score": score,
                        "instance_id": j
                    })

                    logger.info(f"[{request_id}]   └─ Instance {j}: score={score:.3f}, coverage={coverage:.2f}%")
            else:
                logger.info(f"[{request_id}] {cls}: No instances detected")

        logger.info(f"[{request_id}] Completed: {total_instances} instance(s) across {len(classes)} class(es)")
        return results

    except Exception:
        # logger.exception records the message together with the full traceback
        logger.exception(f"[{request_id}] Failed")
        raise


@app.post("/")
async def predict(req: Request):
    request_id = uuid.uuid4().hex[:8]  # unique per request; id() values can recur once objects are freed
    try:
        await vram_manager.acquire(request_id)
        try:
            results = await asyncio.to_thread(
                run_inference,
                req.inputs,
                req.parameters.get("classes", []),
                request_id
            )
            return results
        finally:
            vram_manager.release(request_id)
    except Exception as e:
        logger.error(f"[{request_id}] Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "model": model.__class__.__name__,
        "gpu_available": torch.cuda.is_available(),
        "vram": vram_manager.get_vram_status()
    }


@app.get("/metrics")
async def metrics():
    return vram_manager.get_vram_status()


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
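

# Example client call (hypothetical sketch; assumes the server is reachable on
# localhost:7860 and that "road.jpg" exists):
#
#   import base64, requests
#   with open("road.jpg", "rb") as f:
#       payload = {
#           "inputs": base64.b64encode(f.read()).decode("utf-8"),
#           "parameters": {"classes": ["crack", "pothole"]},
#       }
#   for det in requests.post("http://localhost:7860/", json=payload, timeout=300).json():
#       print(det["label"], det["score"])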