pillipop commited on
Commit
3c5f6c6
·
unverified ·
1 Parent(s): 3c61bf1
app/cloth_segmentation/model.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
  from transformers import pipeline
10
 
11
  from app.simple_segmentation.network import U2NET
12
- from app.utils import image_to_base64, check_or_download_model
13
 
14
  pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
15
 
 
9
  from transformers import pipeline
10
 
11
  from app.simple_segmentation.network import U2NET
12
+ from app.utils import image_to_base64
13
 
14
  pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
15
 
app/segment_anything/model.py CHANGED
@@ -5,8 +5,10 @@ from segment_anything import sam_model_registry, SamPredictor
5
  from base64 import b64encode
6
 
7
  from huggingface_hub import hf_hub_download
 
8
 
9
- chkpt_path = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
 
10
  model_type = "default"
11
 
12
  def initialize_model(model_type, chkpt_path):
@@ -18,6 +20,9 @@ def initialize_model(model_type, chkpt_path):
18
  print("GPU not available. Using CPU.")
19
  return sam
20
 
 
 
 
21
  def preprocess_image(image: Image) -> np.ndarray:
22
  # Implement any necessary preprocessing steps
23
  # Ensure the image is in the expected format
 
5
  from base64 import b64encode
6
 
7
  from huggingface_hub import hf_hub_download
8
+ from app.utils import check_or_download_model
9
 
10
+ MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
11
+ chkpt_path = ".checkpoint/sam_vit_h_4b8939.pth"
12
  model_type = "default"
13
 
14
  def initialize_model(model_type, chkpt_path):
 
20
  print("GPU not available. Using CPU.")
21
  return sam
22
 
23
+ def load_segment_model(checkpoint_path, device='cpu'):
24
+ check_or_download_model(MODEL_URL, checkpoint_path)
25
+
26
  def preprocess_image(image: Image) -> np.ndarray:
27
  # Implement any necessary preprocessing steps
28
  # Ensure the image is in the expected format
app/server.py CHANGED
@@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass
10
  from app.cloth_segmentation.model import Layer, segment
11
  from app.segment_anything.model import predict
12
  from app.simple_segmentation.model import Mode, binary_segment, load_seg_model, Result
 
13
  from app.utils import image_to_base64
14
 
15
  # === Context ===
@@ -30,6 +31,7 @@ class MaskResponse:
30
  @asynccontextmanager
31
  async def lifespan(app: FastAPI):
32
  ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
 
33
  yield
34
  ml_models.clear()
35
 
 
10
  from app.cloth_segmentation.model import Layer, segment
11
  from app.segment_anything.model import predict
12
  from app.simple_segmentation.model import Mode, binary_segment, load_seg_model, Result
13
+ from app.segment_anything.model import load_segment_model
14
  from app.utils import image_to_base64
15
 
16
  # === Context ===
 
31
  @asynccontextmanager
32
  async def lifespan(app: FastAPI):
33
  ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
34
+ ml_models["segment_anything"] = load_segment_model('.checkpoint/sam_vit_h_4b8939.pth')
35
  yield
36
  ml_models.clear()
37
 
app/simple_segmentation/model.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
  from torchvision import transforms
10
 
11
  from app.simple_segmentation.network import U2NET
12
- from app.utils import check_or_download_model
13
 
14
  MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"
15
 
@@ -23,18 +23,6 @@ class Result:
23
  lower_body: Optional[Image.Image] = None
24
  full_body: Optional[Image.Image] = None
25
 
26
- def load_checkpoint(model, checkpoint_path):
27
- model_state_dict = torch.load(
28
- checkpoint_path, map_location=torch.device("cpu"))
29
- new_state_dict = OrderedDict()
30
- for k, v in model_state_dict.items():
31
- name = k[7:] # remove `module.`
32
- new_state_dict[name] = v
33
-
34
- model.load_state_dict(new_state_dict)
35
- print("----checkpoints loaded from path: {}----".format(checkpoint_path))
36
- return model
37
-
38
  def load_seg_model(checkpoint_path, device='cpu'):
39
  net = U2NET(in_ch=3, out_ch=4)
40
  check_or_download_model(MODEL_URL, checkpoint_path)
 
9
  from torchvision import transforms
10
 
11
  from app.simple_segmentation.network import U2NET
12
+ from app.utils import check_or_download_model, load_checkpoint
13
 
14
  MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"
15
 
 
23
  lower_body: Optional[Image.Image] = None
24
  full_body: Optional[Image.Image] = None
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def load_seg_model(checkpoint_path, device='cpu'):
27
  net = U2NET(in_ch=3, out_ch=4)
28
  check_or_download_model(MODEL_URL, checkpoint_path)
app/utils.py CHANGED
@@ -1,10 +1,11 @@
1
- import os
2
  from io import BytesIO
3
  from base64 import b64encode
4
  from urllib.request import urlretrieve
5
  from urllib.parse import urlparse
6
  from PIL import Image
7
  from typing import Optional
 
8
 
9
  def image_to_base64(image: Image.Image | None):
10
  if image == None:
@@ -19,16 +20,52 @@ def is_valid_url(url):
19
  return all([result.scheme, result.netloc])
20
  except ValueError:
21
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def check_or_download_model(model_url, file_path):
24
  if not is_valid_url(model_url):
25
  print("Invalid model URL.")
26
  return
27
 
28
- if not os.path.exists(file_path):
29
- print("No model found, downloading model")
30
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
31
- urlretrieve(model_url, file_path)
32
- print("Model downloaded successfully.")
33
- else:
34
  print("Model already exists at:", file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch
2
  from io import BytesIO
3
  from base64 import b64encode
4
  from urllib.request import urlretrieve
5
  from urllib.parse import urlparse
6
  from PIL import Image
7
  from typing import Optional
8
+ from collections import OrderedDict
9
 
10
  def image_to_base64(image: Image.Image | None):
11
  if image == None:
 
20
  return all([result.scheme, result.netloc])
21
  except ValueError:
22
  return False
23
+
24
+ def download_with_progress(model_url, file_path):
25
+ try:
26
+ response, _ = urlretrieve(model_url, file_path, reporthook=download_progress)
27
+ except Exception as e:
28
+ print(f"Error downloading the model: {e}")
29
+ return False
30
+ else:
31
+ return True
32
+
33
+ def download_progress(block_num, block_size, total_size):
34
+ progress = min(1.0, block_num * block_size / total_size)
35
+ bar_length = 50
36
+ block = int(round(bar_length * progress))
37
+ progress_percent = progress * 100
38
+ progress_bar = f"[{'=' * block}{' ' * (bar_length - block)}] {progress_percent:.2f}%\r"
39
+ print(progress_bar, end='', flush=True)
40
 
41
  def check_or_download_model(model_url, file_path):
42
  if not is_valid_url(model_url):
43
  print("Invalid model URL.")
44
  return
45
 
46
+ if os.path.exists(file_path):
 
 
 
 
 
47
  print("Model already exists at:", file_path)
48
+ else:
49
+ print("No model found, downloading model.")
50
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
51
+ download_with_progress(model_url, file_path)
52
+ print("\nModel downloaded successfully.")
53
+
54
+ def load_checkpoint(model, checkpoint_path):
55
+ # Load model checkpoint
56
+ model_state_dict = torch.load(
57
+ checkpoint_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
58
+
59
+ # Create a new state dictionary without the 'module.' prefix
60
+ new_state_dict = OrderedDict()
61
+ for k, v in model_state_dict.items():
62
+ name = k[7:] # remove `module.`
63
+ new_state_dict[name] = v
64
+
65
+ # Load the new state dictionary into the model
66
+ model.load_state_dict(new_state_dict)
67
+
68
+ # Print a confirmation message
69
+ print("---- Checkpoint loaded from path: {} ----".format(checkpoint_path))
70
+
71
+ return model