Tremick commited on
Commit
7ea3f76
·
verified ·
1 Parent(s): 4c19a9f

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -191
app.py DELETED
@@ -1,191 +0,0 @@
1
- import os
2
- import io
3
- import base64
4
- import torch
5
- import numpy as np
6
- import cv2
7
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import BaseModel
10
- from typing import List, Optional, Union
11
- from PIL import Image
12
- from transformers import Sam3Processor, Sam3Model
13
-
14
- app = FastAPI(title="SAM 3 API", description="Segment Anything Model 3 API for HF Spaces")
15
-
16
- # CORS Setup - Allow all for simplicity in this demo, restrict in production
17
- app.add_middleware(
18
- CORSMiddleware,
19
- allow_origins=["*"],
20
- allow_credentials=True,
21
- allow_methods=["*"],
22
- allow_headers=["*"],
23
- )
24
-
25
- # --- Global Model Variables ---
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
27
- model = None
28
- processor = None
29
-
30
- # --- Startup Event ---
31
- @app.on_event("startup")
32
- async def startup_event():
33
- global model, processor
34
- print(f"Loading SAM 3 Model on {device}...")
35
- try:
36
- processor = Sam3Processor.from_pretrained("facebook/sam3")
37
- model = Sam3Model.from_pretrained("facebook/sam3").to(device)
38
- print("Model loaded successfully!")
39
- except Exception as e:
40
- print(f"Error loading model: {e}")
41
- # In a real deployed environment, we might want to crash or retry.
42
- # For now, we print error.
43
-
44
- # --- Data Models ---
45
- class Point(BaseModel):
46
- x: int
47
- y: int
48
- label: int # 1 for positive, 0 for negative
49
-
50
- class Box(BaseModel):
51
- x1: int
52
- y1: int
53
- x2: int
54
- y2: int
55
- label: int = 1 # 1 for positive, 0 for negative
56
-
57
- class InferenceRequest(BaseModel):
58
- image: str # Base64 encoded image
59
- prompt_type: str # 'point', 'box', 'text', 'everything'
60
- points: Optional[List[Point]] = None
61
- boxes: Optional[List[Box]] = None
62
- text_prompt: Optional[str] = None
63
-
64
- # --- Helper Functions ---
65
- def decode_image(base64_string):
66
- if "," in base64_string:
67
- base64_string = base64_string.split(",")[1]
68
- image_data = base64.b64decode(base64_string)
69
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
70
- return image
71
-
72
- def encode_image(image: Image.Image):
73
- buffered = io.BytesIO()
74
- image.save(buffered, format="PNG")
75
- return base64.b64encode(buffered.getvalue()).decode("utf-8")
76
-
77
- def numpy_to_base64_mask(mask_np):
78
- # mask_np is bool or uint8 (0/1)
79
- mask_img = Image.fromarray((mask_np * 255).astype(np.uint8))
80
- return encode_image(mask_img)
81
-
82
- # --- Endpoints ---
83
-
84
- @app.get("/")
85
- def home():
86
- return {"status": "running", "device": device}
87
-
88
- @app.post("/predict")
89
- async def predict(request: InferenceRequest):
90
- global model, processor
91
- if not model or not processor:
92
- raise HTTPException(status_code=503, detail="Model not loaded yet")
93
-
94
- try:
95
- image = decode_image(request.image)
96
- inputs = None
97
-
98
- # Prepare inputs based on prompt type
99
- if request.prompt_type == "text":
100
- if not request.text_prompt:
101
- raise HTTPException(status_code=400, detail="Text prompt required")
102
- inputs = processor(images=image, text=request.text_prompt, return_tensors="pt").to(device)
103
-
104
- elif request.prompt_type == "box":
105
- if not request.boxes:
106
- raise HTTPException(status_code=400, detail="Box prompt required")
107
- # Format: [[ [x1, y1, x2, y2], ... ]] - Batch size 1
108
- input_boxes = [[[b.x1, b.y1, b.x2, b.y2] for b in request.boxes]]
109
- input_labels = [[[b.label] for b in request.boxes]]
110
- inputs = processor(
111
- images=image,
112
- input_boxes=input_boxes,
113
- input_boxes_labels=input_labels,
114
- return_tensors="pt"
115
- ).to(device)
116
-
117
- elif request.prompt_type == "point":
118
- if not request.points:
119
- raise HTTPException(status_code=400, detail="Point prompt required")
120
- # Format: [[ [x, y], ... ]] - Batch size 1
121
- input_points = [[[p.x, p.y] for p in request.points]]
122
- input_labels = [[[p.label] for p in request.points]]
123
- inputs = processor(
124
- images=image,
125
- input_points=input_points,
126
- input_labels=input_labels,
127
- return_tensors="pt"
128
- ).to(device)
129
-
130
- elif request.prompt_type == "everything":
131
- # For "everything", we might need a different strategy or just use grid points
132
- # SAM 3 doesn't have a built-in "everything" function in the same way SAM 1 did (AutomaticMaskGenerator)
133
- # but we can simulate it or check if transformers supports it.
134
- # For this MVP, let's just return an error or implement a simple grid if possible.
135
- # Transformers Sam3 integration is new. Let's stick to prompts for now or try a grid of points.
136
- # We'll use a simple grid of points for now.
137
- width, height = image.size
138
- grid_size = 32
139
- x = np.linspace(0, width, grid_size)
140
- y = np.linspace(0, height, grid_size)
141
- xv, yv = np.meshgrid(x, y)
142
- grid_points = list(zip(xv.flatten(), yv.flatten()))
143
- input_points = [[list(p) for p in grid_points]]
144
- input_labels = [[1] * len(grid_points)] # All positive
145
- # This might just get one big mask or many. Let's try it.
146
- # Actually, simpler to just say feature not fully supported in this snippet without more complex logic.
147
- # But let's try sending a generic text prompt "object" or "everything" :D
148
- # Let's fallback to text "objects".
149
- inputs = processor(images=image, text="objects", return_tensors="pt").to(device)
150
-
151
-
152
- else:
153
- raise HTTPException(status_code=400, detail="Invalid prompt type")
154
-
155
- # Inference
156
- with torch.no_grad():
157
- outputs = model(**inputs)
158
-
159
- # Post-process
160
- results = processor.post_process_instance_segmentation(
161
- outputs,
162
- threshold=0.5,
163
- mask_threshold=0.5,
164
- target_sizes=[image.size[::-1]] # (height, width)
165
- )[0]
166
-
167
- # Convert results to JSON-serializable format
168
- # results['masks'] is a boolean tensor of shape (num_masks, H, W)
169
- masks = results['masks'].cpu().numpy()
170
- scores = results['scores'].cpu().numpy().tolist()
171
- boxes_out = results['boxes'].cpu().numpy().tolist() # [x1, y1, x2, y2]
172
-
173
- encoded_masks = []
174
- for mask in masks:
175
- encoded_masks.append(numpy_to_base64_mask(mask))
176
-
177
- return {
178
- "masks": encoded_masks,
179
- "scores": scores,
180
- "boxes": boxes_out,
181
- "count": len(scores)
182
- }
183
-
184
- except Exception as e:
185
- import traceback
186
- traceback.print_exc()
187
- raise HTTPException(status_code=500, detail=str(e))
188
-
189
- if __name__ == "__main__":
190
- import uvicorn
191
- uvicorn.run(app, host="0.0.0.0", port=7860)