pillipop commited on
finish
Browse files- app/cloth_segmentation/model.py +1 -1
- app/segment_anything/model.py +6 -1
- app/server.py +2 -0
- app/simple_segmentation/model.py +1 -13
- app/utils.py +44 -7
app/cloth_segmentation/model.py
CHANGED
|
@@ -9,7 +9,7 @@ from PIL import Image
|
|
| 9 |
from transformers import pipeline
|
| 10 |
|
| 11 |
from app.simple_segmentation.network import U2NET
|
| 12 |
-
from app.utils import image_to_base64
|
| 13 |
|
| 14 |
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
|
| 15 |
|
|
|
|
| 9 |
from transformers import pipeline
|
| 10 |
|
| 11 |
from app.simple_segmentation.network import U2NET
|
| 12 |
+
from app.utils import image_to_base64
|
| 13 |
|
| 14 |
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
|
| 15 |
|
app/segment_anything/model.py
CHANGED
|
@@ -5,8 +5,10 @@ from segment_anything import sam_model_registry, SamPredictor
|
|
| 5 |
from base64 import b64encode
|
| 6 |
|
| 7 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 8 |
|
| 9 |
-
|
|
|
|
| 10 |
model_type = "default"
|
| 11 |
|
| 12 |
def initialize_model(model_type, chkpt_path):
|
|
@@ -18,6 +20,9 @@ def initialize_model(model_type, chkpt_path):
|
|
| 18 |
print("GPU not available. Using CPU.")
|
| 19 |
return sam
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
def preprocess_image(image: Image) -> np.ndarray:
|
| 22 |
# Implement any necessary preprocessing steps
|
| 23 |
# Ensure the image is in the expected format
|
|
|
|
| 5 |
from base64 import b64encode
|
| 6 |
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
+
from app.utils import check_or_download_model
|
| 9 |
|
| 10 |
+
MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
| 11 |
+
chkpt_path = ".checkpoint/sam_vit_h_4b8939.pth"
|
| 12 |
model_type = "default"
|
| 13 |
|
| 14 |
def initialize_model(model_type, chkpt_path):
|
|
|
|
| 20 |
print("GPU not available. Using CPU.")
|
| 21 |
return sam
|
| 22 |
|
| 23 |
+
def load_segment_model(checkpoint_path, device='cpu'):
|
| 24 |
+
check_or_download_model(MODEL_URL, checkpoint_path)
|
| 25 |
+
|
| 26 |
def preprocess_image(image: Image) -> np.ndarray:
|
| 27 |
# Implement any necessary preprocessing steps
|
| 28 |
# Ensure the image is in the expected format
|
app/server.py
CHANGED
|
@@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass
|
|
| 10 |
from app.cloth_segmentation.model import Layer, segment
|
| 11 |
from app.segment_anything.model import predict
|
| 12 |
from app.simple_segmentation.model import Mode, binary_segment, load_seg_model, Result
|
|
|
|
| 13 |
from app.utils import image_to_base64
|
| 14 |
|
| 15 |
# === Context ===
|
|
@@ -30,6 +31,7 @@ class MaskResponse:
|
|
| 30 |
@asynccontextmanager
|
| 31 |
async def lifespan(app: FastAPI):
|
| 32 |
ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
|
|
|
|
| 33 |
yield
|
| 34 |
ml_models.clear()
|
| 35 |
|
|
|
|
| 10 |
from app.cloth_segmentation.model import Layer, segment
|
| 11 |
from app.segment_anything.model import predict
|
| 12 |
from app.simple_segmentation.model import Mode, binary_segment, load_seg_model, Result
|
| 13 |
+
from app.segment_anything.model import load_segment_model
|
| 14 |
from app.utils import image_to_base64
|
| 15 |
|
| 16 |
# === Context ===
|
|
|
|
| 31 |
@asynccontextmanager
|
| 32 |
async def lifespan(app: FastAPI):
|
| 33 |
ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
|
| 34 |
+
ml_models["segment_anything"] = load_segment_model('.checkpoint/sam_vit_h_4b8939.pth')
|
| 35 |
yield
|
| 36 |
ml_models.clear()
|
| 37 |
|
app/simple_segmentation/model.py
CHANGED
|
@@ -9,7 +9,7 @@ from PIL import Image
|
|
| 9 |
from torchvision import transforms
|
| 10 |
|
| 11 |
from app.simple_segmentation.network import U2NET
|
| 12 |
-
from app.utils import check_or_download_model
|
| 13 |
|
| 14 |
MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"
|
| 15 |
|
|
@@ -23,18 +23,6 @@ class Result:
|
|
| 23 |
lower_body: Optional[Image.Image] = None
|
| 24 |
full_body: Optional[Image.Image] = None
|
| 25 |
|
| 26 |
-
def load_checkpoint(model, checkpoint_path):
|
| 27 |
-
model_state_dict = torch.load(
|
| 28 |
-
checkpoint_path, map_location=torch.device("cpu"))
|
| 29 |
-
new_state_dict = OrderedDict()
|
| 30 |
-
for k, v in model_state_dict.items():
|
| 31 |
-
name = k[7:] # remove `module.`
|
| 32 |
-
new_state_dict[name] = v
|
| 33 |
-
|
| 34 |
-
model.load_state_dict(new_state_dict)
|
| 35 |
-
print("----checkpoints loaded from path: {}----".format(checkpoint_path))
|
| 36 |
-
return model
|
| 37 |
-
|
| 38 |
def load_seg_model(checkpoint_path, device='cpu'):
|
| 39 |
net = U2NET(in_ch=3, out_ch=4)
|
| 40 |
check_or_download_model(MODEL_URL, checkpoint_path)
|
|
|
|
| 9 |
from torchvision import transforms
|
| 10 |
|
| 11 |
from app.simple_segmentation.network import U2NET
|
| 12 |
+
from app.utils import check_or_download_model, load_checkpoint
|
| 13 |
|
| 14 |
MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"
|
| 15 |
|
|
|
|
| 23 |
lower_body: Optional[Image.Image] = None
|
| 24 |
full_body: Optional[Image.Image] = None
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def load_seg_model(checkpoint_path, device='cpu'):
|
| 27 |
net = U2NET(in_ch=3, out_ch=4)
|
| 28 |
check_or_download_model(MODEL_URL, checkpoint_path)
|
app/utils.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
-
import os
|
| 2 |
from io import BytesIO
|
| 3 |
from base64 import b64encode
|
| 4 |
from urllib.request import urlretrieve
|
| 5 |
from urllib.parse import urlparse
|
| 6 |
from PIL import Image
|
| 7 |
from typing import Optional
|
|
|
|
| 8 |
|
| 9 |
def image_to_base64(image: Image.Image | None):
|
| 10 |
if image == None:
|
|
@@ -19,16 +20,52 @@ def is_valid_url(url):
|
|
| 19 |
return all([result.scheme, result.netloc])
|
| 20 |
except ValueError:
|
| 21 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def check_or_download_model(model_url, file_path):
|
| 24 |
if not is_valid_url(model_url):
|
| 25 |
print("Invalid model URL.")
|
| 26 |
return
|
| 27 |
|
| 28 |
-
if
|
| 29 |
-
print("No model found, downloading model")
|
| 30 |
-
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 31 |
-
urlretrieve(model_url, file_path)
|
| 32 |
-
print("Model downloaded successfully.")
|
| 33 |
-
else:
|
| 34 |
print("Model already exists at:", file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, torch
|
| 2 |
from io import BytesIO
|
| 3 |
from base64 import b64encode
|
| 4 |
from urllib.request import urlretrieve
|
| 5 |
from urllib.parse import urlparse
|
| 6 |
from PIL import Image
|
| 7 |
from typing import Optional
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
|
| 10 |
def image_to_base64(image: Image.Image | None):
|
| 11 |
if image == None:
|
|
|
|
| 20 |
return all([result.scheme, result.netloc])
|
| 21 |
except ValueError:
|
| 22 |
return False
|
| 23 |
+
|
| 24 |
+
def download_with_progress(model_url, file_path):
|
| 25 |
+
try:
|
| 26 |
+
response, _ = urlretrieve(model_url, file_path, reporthook=download_progress)
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"Error downloading the model: {e}")
|
| 29 |
+
return False
|
| 30 |
+
else:
|
| 31 |
+
return True
|
| 32 |
+
|
| 33 |
+
def download_progress(block_num, block_size, total_size):
|
| 34 |
+
progress = min(1.0, block_num * block_size / total_size)
|
| 35 |
+
bar_length = 50
|
| 36 |
+
block = int(round(bar_length * progress))
|
| 37 |
+
progress_percent = progress * 100
|
| 38 |
+
progress_bar = f"[{'=' * block}{' ' * (bar_length - block)}] {progress_percent:.2f}%\r"
|
| 39 |
+
print(progress_bar, end='', flush=True)
|
| 40 |
|
| 41 |
def check_or_download_model(model_url, file_path):
|
| 42 |
if not is_valid_url(model_url):
|
| 43 |
print("Invalid model URL.")
|
| 44 |
return
|
| 45 |
|
| 46 |
+
if os.path.exists(file_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
print("Model already exists at:", file_path)
|
| 48 |
+
else:
|
| 49 |
+
print("No model found, downloading model.")
|
| 50 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 51 |
+
download_with_progress(model_url, file_path)
|
| 52 |
+
print("\nModel downloaded successfully.")
|
| 53 |
+
|
| 54 |
+
def load_checkpoint(model, checkpoint_path):
|
| 55 |
+
# Load model checkpoint
|
| 56 |
+
model_state_dict = torch.load(
|
| 57 |
+
checkpoint_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
| 58 |
+
|
| 59 |
+
# Create a new state dictionary without the 'module.' prefix
|
| 60 |
+
new_state_dict = OrderedDict()
|
| 61 |
+
for k, v in model_state_dict.items():
|
| 62 |
+
name = k[7:] # remove `module.`
|
| 63 |
+
new_state_dict[name] = v
|
| 64 |
+
|
| 65 |
+
# Load the new state dictionary into the model
|
| 66 |
+
model.load_state_dict(new_state_dict)
|
| 67 |
+
|
| 68 |
+
# Print a confirmation message
|
| 69 |
+
print("---- Checkpoint loaded from path: {} ----".format(checkpoint_path))
|
| 70 |
+
|
| 71 |
+
return model
|