Tremick committed on
Commit
9181198
·
verified ·
1 Parent(s): 7ea3f76

Upload 3 files

Browse files
Files changed (2) hide show
  1. Dockerfile +50 -0
  2. app.py +191 -0
Dockerfile ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# GPU-capable base: CUDA 11.8 + cuDNN 8 runtime on Ubuntu 22.04
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

# Unbuffered Python output, no .pyc files, and non-interactive apt
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    DEBIAN_FRONTEND=noninteractive

# System packages; the apt lists are removed in the same layer to keep it small
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-dev \
    git \
    wget \
    ffmpeg \
    libsm6 \
    libxext6 \
    && rm -rf /var/lib/apt/lists/*

# Directory used while installing dependencies
WORKDIR /app

# Bring pip up to date before installing anything else
RUN pip3 install --no-cache-dir --upgrade pip

# Python dependencies.
# torch goes first, from the cu118 wheel index, so its CUDA build matches the base image
RUN pip3 install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Remaining requirements (copied on their own so this layer can be cached)
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Run as an unprivileged user with UID 1000 (Hugging Face Spaces requirement for security)
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# The application lives in the user's home directory
WORKDIR $HOME/app

# Application code, owned by the runtime user
COPY --chown=user . $HOME/app

# Hugging Face Spaces maps port 7860 by default
EXPOSE 7860

# Serve the FastAPI app with uvicorn
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ from typing import List, Optional, Union
11
+ from PIL import Image
12
+ from transformers import Sam3Processor, Sam3Model
13
+
14
app = FastAPI(title="SAM 3 API", description="Segment Anything Model 3 API for HF Spaces")

# CORS setup - every origin is allowed for simplicity in this demo; restrict in production.
# Credentials are disabled on purpose: FastAPI's CORS documentation states that
# allow_origins/allow_methods/allow_headers must not be ["*"] when
# allow_credentials is True, and this API does not use cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Global Model Variables ---
# Inference device: the GPU when torch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Both are filled in by the startup handler; None means "not loaded (yet)".
model = None
processor = None
29
+
30
+ # --- Startup Event ---
31
@app.on_event("startup")
async def startup_event():
    """Load the SAM 3 processor and model into the module-level globals.

    The processor is loaded first, then the model, which is moved to
    ``device``. Any exception is printed and swallowed, so the application
    still starts; whichever global was not assigned keeps its previous value.
    """
    global model, processor
    checkpoint = "facebook/sam3"
    print(f"Loading SAM 3 Model on {device}...")
    try:
        processor = Sam3Processor.from_pretrained(checkpoint)
        loaded_model = Sam3Model.from_pretrained(checkpoint)
        model = loaded_model.to(device)
        print("Model loaded successfully!")
    except Exception as exc:
        # Best-effort on purpose: the failure is only reported here.
        # A real deployment might prefer to crash or retry instead.
        print(f"Error loading model: {exc}")
43
+
44
+ # --- Data Models ---
45
class Point(BaseModel):
    # Point prompt: integer x/y position plus a label.
    x: int
    y: int
    # 1 = positive point, 0 = negative point
    label: int
49
+
50
class Box(BaseModel):
    # Box prompt given by two integer corners (x1, y1) and (x2, y2).
    x1: int
    y1: int
    x2: int
    y2: int
    # 1 = positive box, 0 = negative box; positive when omitted
    label: int = 1
56
+
57
class InferenceRequest(BaseModel):
    # Request body of the /predict endpoint.
    # Base64-encoded image
    image: str
    # One of 'point', 'box', 'text', 'everything'
    prompt_type: str
    # Prompt payloads; which one is read depends on prompt_type
    points: Optional[List[Point]] = None
    boxes: Optional[List[Box]] = None
    text_prompt: Optional[str] = None
63
+
64
+ # --- Helper Functions ---
65
def decode_image(base64_string):
    """Decode a base64 string into an RGB PIL image.

    When the string contains a comma, only the second comma-separated field
    is decoded (for example the payload that follows the header of a
    ``data:`` URI); otherwise the whole string is decoded.
    """
    has_prefix = "," in base64_string
    payload = base64_string.split(",")[1] if has_prefix else base64_string
    raw_bytes = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
71
+
72
def encode_image(image: Image.Image):
    """Serialize *image* as PNG and return it as a base64 UTF-8 string."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
76
+
77
def numpy_to_base64_mask(mask_np):
    """Convert a mask array to a base64-encoded PNG string.

    The mask is expected to be bool or uint8 with values 0/1; it is scaled
    by 255 and cast to uint8 before being turned into an image.
    """
    scaled = mask_np * 255
    pixels = scaled.astype(np.uint8)
    return encode_image(Image.fromarray(pixels))
81
+
82
+ # --- Endpoints ---
83
+
84
@app.get("/")
def home():
    """Report that the service is running and which device it uses."""
    payload = dict(status="running", device=device)
    return payload
87
+
88
def _load_request_image(encoded_image):
    """Decode the request image, reporting undecodable input as HTTP 400."""
    try:
        return decode_image(encoded_image)
    except (ValueError, OSError) as exc:
        # binascii.Error (malformed base64) is a ValueError; PIL reports
        # unreadable or truncated image data with OSError subclasses.
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc


def _build_model_inputs(request, image):
    """Turn the request's prompt into processor inputs moved to ``device``.

    Raises:
        HTTPException: 400 when ``prompt_type`` is unknown or the payload the
            chosen prompt type needs is missing or empty.
    """
    prompt_type = request.prompt_type

    if prompt_type == "text":
        if not request.text_prompt:
            raise HTTPException(status_code=400, detail="Text prompt required")
        return processor(images=image, text=request.text_prompt, return_tensors="pt").to(device)

    if prompt_type == "box":
        if not request.boxes:
            raise HTTPException(status_code=400, detail="Box prompt required")
        # Format: [[ [x1, y1, x2, y2], ... ]] - batch size 1
        input_boxes = [[[b.x1, b.y1, b.x2, b.y2] for b in request.boxes]]
        # NOTE(review): labels are nested as [batch, num_boxes, 1]; the
        # Transformers SAM 3 examples appear to use [batch, num_boxes]
        # (e.g. [[1, 0]]) - confirm against the processor documentation.
        input_labels = [[[b.label] for b in request.boxes]]
        return processor(
            images=image,
            input_boxes=input_boxes,
            input_boxes_labels=input_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "point":
        if not request.points:
            raise HTTPException(status_code=400, detail="Point prompt required")
        # Format: [[ [x, y], ... ]] - batch size 1
        # NOTE(review): confirm that Sam3Processor accepts
        # input_points/input_labels and this nesting depth.
        input_points = [[[p.x, p.y] for p in request.points]]
        input_labels = [[[p.label] for p in request.points]]
        return processor(
            images=image,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

    if prompt_type == "everything":
        # No dedicated "segment everything" mode is used here; this falls
        # back to the generic text prompt "objects".
        return processor(images=image, text="objects", return_tensors="pt").to(device)

    raise HTTPException(status_code=400, detail="Invalid prompt type")


@app.post("/predict")
async def predict(request: InferenceRequest):
    """Run SAM 3 on the submitted image and prompt.

    Returns:
        A dict with ``masks`` (one base64-encoded PNG per mask), ``scores``,
        ``boxes`` ([x1, y1, x2, y2] lists) and ``count``.

    Raises:
        HTTPException: 503 while the model or processor is not loaded,
            400 for an undecodable image or an invalid/incomplete prompt,
            500 for any other failure during preprocessing or inference.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        image = _load_request_image(request.image)
        inputs = _build_model_inputs(request, image)

        # NOTE(review): inference is a synchronous call inside this
        # coroutine, so the event loop is blocked while it runs.
        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=0.5,
            mask_threshold=0.5,
            target_sizes=[image.size[::-1]],  # PIL gives (width, height); pass (height, width)
        )[0]

        # Convert results to JSON-serializable values.
        # results['masks'] is expected to be a boolean tensor of shape (num_masks, H, W).
        masks = results['masks'].cpu().numpy()
        scores = results['scores'].cpu().numpy().tolist()
        boxes_out = results['boxes'].cpu().numpy().tolist()  # [x1, y1, x2, y2]

        return {
            "masks": [numpy_to_base64_mask(mask) for mask in masks],
            "scores": scores,
            "boxes": boxes_out,
            "count": len(scores),
        }

    except HTTPException:
        # Deliberate client errors (400) must keep their status code instead
        # of being rewrapped as 500 by the generic handler below.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
188
+
189
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)