pillipop committed on
Commit
15a589e
·
unverified ·
1 Parent(s): b95e214

working on sam encoder

Browse files
Makefile CHANGED
@@ -1,2 +1,6 @@
1
  dev:
2
  uvicorn app.server:app --reload
 
 
 
 
 
1
  dev:
2
  uvicorn app.server:app --reload
3
+
4
# Fetch the SAM ViT-H checkpoint and place it where the server expects it:
# app/cloth_segmentation/model.py reads "app/sam_models/sam_vit_h_4b8939.pth",
# so the previous `mv ... sam_model/` destination was wrong.
download_sam_model:
	mkdir -p app/sam_models
	wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
	mv sam_vit_h_4b8939.pth app/sam_models/
app/cloth_segmentation/model.py CHANGED
@@ -1,12 +1,18 @@
 
 
1
  from base64 import b64encode
2
  from dataclasses import dataclass
3
  from io import BytesIO
4
  from PIL import Image
 
5
 
6
  from transformers import pipeline
7
 
8
  pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
9
 
 
 
 
10
 
11
  @dataclass
12
  class Layer:
@@ -31,3 +37,12 @@ def segment(image: Image) -> [Layer]:
31
  result.append(Layer(t['label'], image_to_base64(t['mask'])))
32
 
33
  return result
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("..")
3
  from base64 import b64encode
4
  from dataclasses import dataclass
5
  from io import BytesIO
6
  from PIL import Image
7
+ from segment_anything import sam_model_registry, SamPredictor
8
 
9
  from transformers import pipeline
10
 
11
  pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
12
 
import torch

# Checkpoint location — matches where the Makefile `download_sam_model`
# target deposits the file.
sam_checkpoint = "app/sam_models/sam_vit_h_4b8939.pth"
# Registry key for the ViT-H variant of SAM (matches the checkpoint above).
model_type = "vit_h"
# Prefer the GPU, but fall back to CPU so the server still starts on
# machines without CUDA (the hard-coded "cuda" crashed there).
device = "cuda" if torch.cuda.is_available() else "cpu"
 
17
  @dataclass
18
  class Layer:
 
37
  result.append(Layer(t['label'], image_to_base64(t['mask'])))
38
 
39
  return result
40
+
41
# Cached SAM model — loading the ~2.4 GB vit_h checkpoint takes seconds,
# so do it once instead of on every request.
_sam = None


def _load_sam():
    """Load and cache the SAM model configured at module level."""
    global _sam
    if _sam is None:
        _sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        _sam.to(device=device)
    return _sam


def sam_anything(image: Image) -> "list[Layer]":
    """Run SAM automatic mask generation on *image*.

    Returns one Layer per detected instance, with the mask encoded as a
    base64 image — the same shape as `segment`'s output, so the FastAPI
    response model (`List[Layer]`) validates.

    Note: the previous implementation crashed — SamPredictor.set_image
    returns None (so `pred.pred_masks` raised AttributeError) and expects
    an RGB numpy array, not a PIL Image.
    """
    # Local imports keep module import cheap when SAM is never used.
    from segment_anything import SamAutomaticMaskGenerator
    import numpy as np

    generator = SamAutomaticMaskGenerator(_load_sam())
    masks = generator.generate(np.array(image.convert("RGB")))
    print(f"Predicted {len(masks)} instances")

    layers = []
    for m in masks:
        # m["segmentation"] is a boolean HxW array; convert to an 8-bit
        # black/white mask image before base64-encoding.
        mask_img = Image.fromarray(m["segmentation"].astype(np.uint8) * 255)
        layers.append(Layer("sam", image_to_base64(mask_img)))
    return layers
app/server.py CHANGED
@@ -1,5 +1,5 @@
1
  from PIL import Image
2
- from app.cloth_segmentation.model import Layer, segment
3
  from typing import List
4
 
5
  from contextlib import asynccontextmanager
@@ -32,3 +32,8 @@ def index():
32
  def mask(upload: UploadFile) -> List[Layer]:
33
  image = Image.open(upload.file)
34
  return segment(image)
 
 
 
 
 
 
1
  from PIL import Image
2
+ from app.cloth_segmentation.model import Layer, segment, sam_anything
3
  from typing import List
4
 
5
  from contextlib import asynccontextmanager
 
32
  def mask(upload: UploadFile) -> List[Layer]:
33
  image = Image.open(upload.file)
34
  return segment(image)
35
+
36
@app.post("/encode")
def encode(upload: UploadFile) -> List[Layer]:
    """Segment the uploaded image with SAM and return the resulting layers."""
    return sam_anything(Image.open(upload.file))
requirements.txt CHANGED
@@ -8,3 +8,5 @@ numpy==1.24.3
8
  torch==2.0.*
9
  torchvision
10
  transformers==4.30.1
 
 
 
8
  torch==2.0.*
9
  torchvision
10
  transformers==4.30.1
11
+
12
+ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588