Spaces:
Sleeping
Sleeping
HERIUN
commited on
Commit
·
1081f7c
1
Parent(s):
ba5f22e
add models
Browse files- models/DocScanner/inference.py +4 -4
- rect_main.py +2 -2
models/DocScanner/inference.py
CHANGED
|
@@ -37,12 +37,12 @@ class Net(nn.Module):
|
|
| 37 |
return bm, msk
|
| 38 |
|
| 39 |
|
| 40 |
-
def reload_seg_model(model, path=""):
|
| 41 |
if not bool(path):
|
| 42 |
return model
|
| 43 |
else:
|
| 44 |
model_dict = model.state_dict()
|
| 45 |
-
pretrained_dict = torch.load(path, map_location=
|
| 46 |
pretrained_dict = {
|
| 47 |
k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
|
| 48 |
}
|
|
@@ -52,12 +52,12 @@ def reload_seg_model(model, path=""):
|
|
| 52 |
return model
|
| 53 |
|
| 54 |
|
| 55 |
-
def reload_rec_model(model, path=""):
|
| 56 |
if not bool(path):
|
| 57 |
return model
|
| 58 |
else:
|
| 59 |
model_dict = model.state_dict()
|
| 60 |
-
pretrained_dict = torch.load(path, map_location=
|
| 61 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
| 62 |
model_dict.update(pretrained_dict)
|
| 63 |
model.load_state_dict(model_dict)
|
|
|
|
| 37 |
return bm, msk
|
| 38 |
|
| 39 |
|
| 40 |
+
def reload_seg_model(cuda, model, path=""):
|
| 41 |
if not bool(path):
|
| 42 |
return model
|
| 43 |
else:
|
| 44 |
model_dict = model.state_dict()
|
| 45 |
+
pretrained_dict = torch.load(path, map_location=cuda)
|
| 46 |
pretrained_dict = {
|
| 47 |
k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
|
| 48 |
}
|
|
|
|
| 52 |
return model
|
| 53 |
|
| 54 |
|
| 55 |
+
def reload_rec_model(cuda, model, path=""):
|
| 56 |
if not bool(path):
|
| 57 |
return model
|
| 58 |
else:
|
| 59 |
model_dict = model.state_dict()
|
| 60 |
+
pretrained_dict = torch.load(path, map_location=cuda)
|
| 61 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
| 62 |
model_dict.update(pretrained_dict)
|
| 63 |
model.load_state_dict(model_dict)
|
rect_main.py
CHANGED
|
@@ -33,8 +33,8 @@ def load_geotrp_model(cuda, path=""):
|
|
| 33 |
def load_docscanner_model(cuda, path_l="", path_m=""):
|
| 34 |
|
| 35 |
net = DocScanner.Net().to(cuda)
|
| 36 |
-
DocScanner.reload_seg_model(net.msk, path_m)
|
| 37 |
-
DocScanner.reload_rec_model(net.bm, path_l)
|
| 38 |
net.eval()
|
| 39 |
|
| 40 |
return net
|
|
|
|
| 33 |
def load_docscanner_model(cuda, path_l="", path_m=""):
|
| 34 |
|
| 35 |
net = DocScanner.Net().to(cuda)
|
| 36 |
+
DocScanner.reload_seg_model(cuda, net.msk, path_m)
|
| 37 |
+
DocScanner.reload_rec_model(cuda, net.bm, path_l)
|
| 38 |
net.eval()
|
| 39 |
|
| 40 |
return net
|