Dramb committed on
Commit
f0403f5
·
verified ·
1 Parent(s): c8ed8e3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +55 -28
inference.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Dict
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
@@ -6,34 +7,60 @@ from skimage import transform
6
  from sam2.build_sam import build_sam2
7
  from sam2.sam2_image_predictor import SAM2ImagePredictor
8
 
9
class PreTrainedModel:
    """Box-prompted medical image segmentation with a MedSAM2 checkpoint.

    Builds the tiny SAM2 hierarchical backbone once at construction and
    serves segmentation requests through ``__call__``.
    """

    def __init__(self):
        # Prefer GPU when one is visible, otherwise run on CPU.
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = build_sam2(
            "sam2_hiera_t",
            "MedSAM2_pretrain_10ep_b1_AMD-SD_sam2_hiera_t.pth",
            device=target_device
        )
        self.predictor = SAM2ImagePredictor(self.model)

    def __call__(self, inputs: Dict):
        """Segment the region described by ``inputs['box']``.

        Parameters
        ----------
        inputs : dict with keys
            ``image`` — anything ``PIL.Image.open`` accepts (path/file object);
            ``box``   — [x1, y1, x2, y2] in original-image pixel coordinates.

        Returns
        -------
        dict with key ``mask``: the predicted mask as a nested uint8 list.
        """
        rgb = np.array(Image.open(inputs["image"]).convert("RGB"))
        prompt_box = list(map(float, inputs["box"]))

        # Guarantee a 3-channel array. ``convert("RGB")`` already yields one;
        # the repeat branch only covers a hypothetical single-channel input.
        if rgb.shape[2] == 3:
            three_channel = rgb
        else:
            three_channel = np.repeat(rgb[:, :, None], 3, axis=-1)
        resized = transform.resize(
            three_channel, (1024, 1024), preserve_range=True
        ).astype(np.uint8)

        # Rescale the prompt box from original pixels onto the 1024 grid.
        width, height = rgb.shape[1], rgb.shape[0]
        scaled_box = np.array(prompt_box) / [width, height, width, height] * 1024
        scaled_box = scaled_box[None, :]

        autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.inference_mode(), torch.autocast(autocast_device, dtype=torch.bfloat16):
            self.predictor.set_image(resized)
            masks, _, _ = self.predictor.predict(
                point_coords=None,
                point_labels=None,
                box=scaled_box,
                multimask_output=False
            )

        mask = masks[0].astype(np.uint8)
        return {"mask": mask.tolist()}
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
  import torch
4
  import numpy as np
5
  from PIL import Image
 
7
  from sam2.build_sam import build_sam2
8
  from sam2.sam2_image_predictor import SAM2ImagePredictor
9
 
10
# One-time model initialisation at module import.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): entering autocast globally and never exiting mirrors the SAM2
# demo notebooks; bfloat16 autocast on CPU may be slow or unsupported on some
# builds — confirm this is intentional.
torch.autocast(device_type=device, dtype=torch.bfloat16).__enter__()
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # Compute capability >= 8 (Ampere and newer) supports TF32 fast paths.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Point this at your model checkpoint; loaded once and shared by predict().
medsam2_model = build_sam2('sam2_hiera_t', 'MedSAM2_pretrain_10ep_b1_AMD-SD_sam2_hiera_t.pth', device=device)
predictor = SAM2ImagePredictor(medsam2_model)
20
+
21
+ # --- Helper functions ---
22
def decode_image(base64_str):
    """Decode a base64-encoded image into an RGB numpy array of shape (H, W, 3)."""
    raw_bytes = base64.b64decode(base64_str)
    with io.BytesIO(raw_bytes) as stream:
        pil_image = Image.open(stream).convert("RGB")
        return np.array(pil_image)
26
+
27
def encode_mask_to_base64(mask_np):
    """Serialize a mask array as a base64-encoded PNG string.

    Assumes mask values are in {0, 1} so that ``* 255`` maps them to
    {0, 255} — TODO confirm against the predictor's output range.
    """
    scaled = (mask_np * 255).astype(np.uint8)
    png_buffer = io.BytesIO()
    Image.fromarray(scaled).save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
32
+
33
+ # --- Main inference entry point ---
34
def predict(input_dict):
    """Run box-prompted MedSAM2 segmentation on a base64-encoded image.

    Parameters
    ----------
    input_dict : dict with keys
        ``image`` — base64-encoded image bytes (any format PIL can open);
        ``box``   — [x1, y1, x2, y2] prompt box in original-image pixels.

    Returns
    -------
    dict
        ``{"mask": <base64 PNG>, "shape": [H, W]}`` on success, or
        ``{"error": <message>}`` on any failure (boundary-level catch-all).
    """
    try:
        image_b64 = input_dict["image"]
        box = input_dict["box"]  # [x1, y1, x2, y2]
        # Fail early with a clear message instead of an opaque numpy
        # broadcasting error during normalization below.
        if len(box) != 4:
            raise ValueError(f"box must have 4 coordinates, got {len(box)}")

        image = decode_image(image_b64)

        # MedSAM2 expects a 1024x1024 RGB input; bicubic resize, 0-255 range kept.
        img_3c = np.repeat(image[:, :, None], 3, axis=-1) if len(image.shape) == 2 else image
        img_1024 = transform.resize(
            img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)

        # Rescale the prompt box from original pixel coordinates onto the 1024 grid.
        box_np = np.array(box, dtype=float)
        scale = np.array([img_3c.shape[1], img_3c.shape[0], img_3c.shape[1], img_3c.shape[0]])
        box_1024 = box_np / scale * 1024

        with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
            predictor.set_image(img_1024)
            masks, _, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=np.array([box_1024]),
                multimask_output=False,
            )

        mask = masks[0]  # 2-D array; presumably {0, 1} values — confirm upstream

        return {
            # JSON-safe payload: base64 PNG plus a plain list for the shape
            # (was a tuple, which not every serializer accepts).
            "mask": encode_mask_to_base64(mask.astype(np.uint8)),
            "shape": list(mask.shape),
        }

    except Exception as e:
        # API boundary: convert any failure into a structured error response.
        return {"error": str(e)}