ivanm151 committed on
Commit
3dc4dee
·
1 Parent(s): 56295a6
Files changed (9) hide show
  1. .gitignore +2 -0
  2. Dockerfile +16 -0
  3. app.py +38 -0
  4. models.py +20 -0
  5. requirements.txt +8 -0
  6. utils.py +36 -0
  7. weights/class1.pth +3 -0
  8. weights/class2.pth +3 -0
  9. weights/seg.pth +3 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ .idea
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

# Create and switch to a non-root user; uid 1000 matches the HF Spaces
# runtime convention (see the doc linked above).
RUN useradd -m -u 1000 user
USER user
# Make user-level pip installs (~/.local/bin) reachable on PATH.
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy requirements first so the pip layer is cached and only re-runs
# when requirements.txt itself changes, not on every code edit.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
# 7860 is the default port HF Spaces expects the app to listen on.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ import torch
3
+ from models import load_model1
4
+ from utils import preprocess_image, postprocess_mask, resize_mask, mask_to_base64
5
+ import numpy as np
6
+ from PIL import Image
7
+ import io
8
+
app = FastAPI()

# Load the segmentation model once at module import (global singleton),
# so every request reuses the same in-memory weights instead of
# re-reading the checkpoint from disk.
model1 = load_model1()
@app.get("/")
def greet_json():
    """Trivial root endpoint; confirms the service is up."""
    payload = {"Hello": "World!"}
    return payload
18
+
19
+
@app.post("/predict1")
async def predict1(file: UploadFile = File(...)):
    """Segment an uploaded image with the global U-Net model.

    Reads the uploaded file bytes, runs inference, and returns the
    predicted binary mask resized to 100x100 and 224x224, each encoded
    as a base64 PNG string.
    """
    raw = await file.read()
    pil_image = Image.open(io.BytesIO(raw)).convert('RGB')
    pixels = np.array(pil_image)

    tensor = preprocess_image(pixels)
    batch = tensor.unsqueeze(0)  # model expects a leading batch dimension
    with torch.no_grad():
        logits = model1(batch)

    # Binary mask at the model's native 256x256 resolution.
    binary_mask = postprocess_mask(logits)

    return {
        "mask_100_base64": mask_to_base64(resize_mask(binary_mask, 100)),
        "mask_224_base64": mask_to_base64(resize_mask(binary_mask, 224)),
    }
models.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import segmentation_models_pytorch as smp

# Inference device; CPU-only (no GPU assumed in the serving environment).
DEVICE = torch.device('cpu')

# Lazily-initialized module-level singleton: the model is constructed and
# its weights are read from disk at most once per process.
model1 = None

def load_model1(weights_path='weights/seg.pth'):
    """Return the cached segmentation model, building it on first call.

    Args:
        weights_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        A ``smp.Unet`` (MobileNetV2 encoder, 1 output class) in eval
        mode on ``DEVICE``.
    """
    global model1
    if model1 is None:
        model1 = smp.Unet(
            encoder_name="mobilenet_v2",
            encoder_weights=None,  # weights come from the checkpoint below
            in_channels=3,
            classes=1,  # single-channel (binary) segmentation head
        ).to(DEVICE)
        # weights_only=True restricts torch.load to tensor data instead of
        # full pickle, so a tampered checkpoint file cannot execute
        # arbitrary code on load (requires torch >= 1.13).
        state_dict = torch.load(weights_path, map_location=DEVICE,
                                weights_only=True)
        model1.load_state_dict(state_dict)
        model1.eval()
    return model1
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ segmentation_models_pytorch
5
+ albumentations
6
+ pillow
7
+ numpy
8
+ opencv-python-headless
utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import albumentations as A
3
+ from albumentations.pytorch import ToTensorV2
4
+ import torch
5
+ import cv2
6
+ from PIL import Image
7
+ import io
8
+ import base64
9
+
# Preprocessing: mirrors the validation-time transform (resize -> normalize
# -> CHW tensor).
preprocess_transform = A.Compose([
    A.Resize(256, 256),  # model's expected input resolution
    A.Normalize(),  # albumentations defaults (ImageNet mean/std) — assumed to match training; verify
    ToTensorV2()
])
16
+
def preprocess_image(image_np: np.ndarray) -> torch.Tensor:
    """Run the validation-time transform and return the resulting tensor."""
    transformed = preprocess_transform(image=image_np)
    return transformed['image']
20
+
def postprocess_mask(logits: torch.Tensor, threshold: float = 0.5) -> np.ndarray:
    """Turn raw model logits into a binary float mask.

    A pixel is foreground iff sigmoid(logit) > ``threshold``. The batch
    and channel singleton dims are squeezed away, yielding (256, 256)
    for a single 1-channel prediction.
    """
    probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    return np.where(probabilities > threshold, 1.0, 0.0).astype(np.float32)
25
+
def resize_mask(mask: np.ndarray, size: int) -> np.ndarray:
    """Resize a binary mask to (size, size) without blending values.

    Nearest-neighbour interpolation keeps the output strictly 0/1; any
    smoother kernel would introduce fractional edge pixels.
    """
    scaled = cv2.resize(mask, dsize=(size, size),
                        interpolation=cv2.INTER_NEAREST)
    return scaled.astype(np.float32)
30
+
def mask_to_base64(mask: np.ndarray) -> str:
    """Encode a 0/1 mask as a base64 PNG string (grayscale, 0/255)."""
    gray = (mask * 255).astype(np.uint8)
    image = Image.fromarray(gray).convert('L')
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')
weights/class1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5767f1246bfe8ee0077a0eefda6c8a1a66e8639de3fc1d94bacf7254633a5f2
3
+ size 9205515
weights/class2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf1835ce202339850361b67f914d7768400f981aa2f0c1a6a29e7f268a749f18
3
+ size 9174411
weights/seg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45c364012ded2dbb389ffe61177056f8b574d5f58822dc3dd7be8a10a20459d7
3
+ size 26805291