Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -39,27 +39,25 @@ models = [
|
|
| 39 |
|
| 40 |
def setup_model(config_file, model_path=None):
|
| 41 |
cfg = get_cfg()
|
| 42 |
-
|
| 43 |
-
if
|
| 44 |
-
cfg.merge_from_file(model_zoo.get_config_file(config_file))
|
| 45 |
-
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
|
| 46 |
-
else:
|
| 47 |
cfg.merge_from_file(config_file)
|
| 48 |
cfg.MODEL.WEIGHTS = model_path
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
if not torch.cuda.is_available():
|
| 51 |
cfg.MODEL.DEVICE = "cpu"
|
| 52 |
-
|
| 53 |
return cfg
|
| 54 |
|
| 55 |
for model in models:
|
| 56 |
if model["name"] == "Custom Model":
|
| 57 |
model["cfg"] = setup_model(model["config_file"], model["model_path"])
|
|
|
|
| 58 |
else:
|
| 59 |
model["cfg"] = setup_model(model["config_file"])
|
| 60 |
-
if model["name"] == "Custom Model":
|
| 61 |
-
model["metadata"] = MetadataCatalog.get("teng-valid")
|
| 62 |
-
else:
|
| 63 |
model["metadata"] = MetadataCatalog.get(model["cfg"].DATASETS.TRAIN[0])
|
| 64 |
|
| 65 |
def inference(image_url, image, min_score, model_name):
|
|
|
|
| 39 |
|
| 40 |
def setup_model(config_file, model_path=None):
|
| 41 |
cfg = get_cfg()
|
| 42 |
+
|
| 43 |
+
if model_path:
|
|
|
|
|
|
|
|
|
|
| 44 |
cfg.merge_from_file(config_file)
|
| 45 |
cfg.MODEL.WEIGHTS = model_path
|
| 46 |
+
else:
|
| 47 |
+
cfg.merge_from_file(model_zoo.get_config_file(config_file))
|
| 48 |
+
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
|
| 49 |
+
|
| 50 |
if not torch.cuda.is_available():
|
| 51 |
cfg.MODEL.DEVICE = "cpu"
|
| 52 |
+
|
| 53 |
return cfg
|
| 54 |
|
| 55 |
for model in models:
|
| 56 |
if model["name"] == "Custom Model":
|
| 57 |
model["cfg"] = setup_model(model["config_file"], model["model_path"])
|
| 58 |
+
model["metadata"] = MetadataCatalog.get("teng-valid")
|
| 59 |
else:
|
| 60 |
model["cfg"] = setup_model(model["config_file"])
|
|
|
|
|
|
|
|
|
|
| 61 |
model["metadata"] = MetadataCatalog.get(model["cfg"].DATASETS.TRAIN[0])
|
| 62 |
|
| 63 |
def inference(image_url, image, min_score, model_name):
|