Update app.py
Browse files
app.py
CHANGED
|
@@ -11,9 +11,11 @@ class_map = ClassMap(['raccoon','banana'])
|
|
| 11 |
|
| 12 |
size = 384
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
# load from model_repo:
|
|
@@ -21,7 +23,7 @@ model_2 = models.torchvision.retinanet.model(
|
|
| 21 |
# from huggingface_hub import hf_hub_download
|
| 22 |
# hf_hub_download(repo_id="Alesteba/deep_model_02", filename="retinanet_racoon.pth")
|
| 23 |
|
| 24 |
-
state_dict = torch.load('./
|
| 25 |
|
| 26 |
model_2.load_state_dict(state_dict)
|
| 27 |
|
|
@@ -33,7 +35,7 @@ def predict(img):
|
|
| 33 |
|
| 34 |
img = PIL.Image.fromarray(img, "RGB")
|
| 35 |
|
| 36 |
-
pred_dict_2 =
|
| 37 |
|
| 38 |
img,
|
| 39 |
infer_tfms,
|
|
|
|
| 11 |
|
| 12 |
size = 384
|
| 13 |
|
| 14 |
+
model_type = models.mmdet.retinanet
|
| 15 |
+
|
| 16 |
+
model_2 = model_type.model(
|
| 17 |
+
backbone= model_type.backbones.swin_t_p4_w7_fpn_1x_coco (pretrained=True),
|
| 18 |
+
num_classes=len(class_map)
|
| 19 |
)
|
| 20 |
|
| 21 |
# load from model_repo:
|
|
|
|
| 23 |
# from huggingface_hub import hf_hub_download
|
| 24 |
# hf_hub_download(repo_id="Alesteba/deep_model_02", filename="retinanet_racoon.pth")
|
| 25 |
|
| 26 |
+
state_dict = torch.load('./mmdet_racoon.pth', map_location=torch.device('cpu'))
|
| 27 |
|
| 28 |
model_2.load_state_dict(state_dict)
|
| 29 |
|
|
|
|
| 35 |
|
| 36 |
img = PIL.Image.fromarray(img, "RGB")
|
| 37 |
|
| 38 |
+
pred_dict_2 = model_type.fastai.end2end_detect(
|
| 39 |
|
| 40 |
img,
|
| 41 |
infer_tfms,
|