pillipop committed on
working on sam encoder
Browse files- Makefile +4 -0
- app/cloth_segmentation/model.py +15 -0
- app/server.py +6 -1
- requirements.txt +2 -0
Makefile
CHANGED
|
@@ -1,2 +1,6 @@
|
|
| 1 |
dev:
|
| 2 |
uvicorn app.server:app --reload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
dev:
|
| 2 |
uvicorn app.server:app --reload
|
| 3 |
+
|
| 4 |
+
download_sam_model:
|
| 5 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
| 6 |
+
mv sam_vit_h_4b8939.pth sam_model/
|
app/cloth_segmentation/model.py
CHANGED
|
@@ -1,12 +1,18 @@
|
|
|
|
|
|
|
|
| 1 |
from base64 import b64encode
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from io import BytesIO
|
| 4 |
from PIL import Image
|
|
|
|
| 5 |
|
| 6 |
from transformers import pipeline
|
| 7 |
|
| 8 |
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class Layer:
|
|
@@ -31,3 +37,12 @@ def segment(image: Image) -> [Layer]:
|
|
| 31 |
result.append(Layer(t['label'], image_to_base64(t['mask'])))
|
| 32 |
|
| 33 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append("..")
|
| 3 |
from base64 import b64encode
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from io import BytesIO
|
| 6 |
from PIL import Image
|
| 7 |
+
from segment_anything import sam_model_registry, SamPredictor
|
| 8 |
|
| 9 |
from transformers import pipeline
|
| 10 |
|
| 11 |
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
|
| 12 |
|
| 13 |
+
sam_checkpoint = "app/sam_models/sam_vit_h_4b8939.pth"
|
| 14 |
+
model_type = "vit_h"
|
| 15 |
+
device = "cuda"
|
| 16 |
|
| 17 |
@dataclass
|
| 18 |
class Layer:
|
|
|
|
| 37 |
result.append(Layer(t['label'], image_to_base64(t['mask'])))
|
| 38 |
|
| 39 |
return result
|
| 40 |
+
|
| 41 |
+
def sam_anything(image: Image) -> [Layer]:
|
| 42 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
| 43 |
+
sam.to(device=device)
|
| 44 |
+
|
| 45 |
+
predictor = SamPredictor(sam)
|
| 46 |
+
pred = predictor.set_image(image)
|
| 47 |
+
print(f"Predicted {len(pred.pred_masks)} instances")
|
| 48 |
+
return pred
|
app/server.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from PIL import Image
|
| 2 |
-
from app.cloth_segmentation.model import Layer, segment
|
| 3 |
from typing import List
|
| 4 |
|
| 5 |
from contextlib import asynccontextmanager
|
|
@@ -32,3 +32,8 @@ def index():
|
|
| 32 |
def mask(upload: UploadFile) -> List[Layer]:
|
| 33 |
image = Image.open(upload.file)
|
| 34 |
return segment(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from PIL import Image
|
| 2 |
+
from app.cloth_segmentation.model import Layer, segment, sam_anything
|
| 3 |
from typing import List
|
| 4 |
|
| 5 |
from contextlib import asynccontextmanager
|
|
|
|
| 32 |
def mask(upload: UploadFile) -> List[Layer]:
|
| 33 |
image = Image.open(upload.file)
|
| 34 |
return segment(image)
|
| 35 |
+
|
| 36 |
+
@app.post("/encode")
|
| 37 |
+
def encode(upload: UploadFile) -> List[Layer]:
|
| 38 |
+
image = Image.open(upload.file)
|
| 39 |
+
return sam_anything(image)
|
requirements.txt
CHANGED
|
@@ -8,3 +8,5 @@ numpy==1.24.3
|
|
| 8 |
torch==2.0.*
|
| 9 |
torchvision
|
| 10 |
transformers==4.30.1
|
|
|
|
|
|
|
|
|
| 8 |
torch==2.0.*
|
| 9 |
torchvision
|
| 10 |
transformers==4.30.1
|
| 11 |
+
|
| 12 |
+
segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
|