Update app.py
Browse files
app.py
CHANGED
|
@@ -36,12 +36,16 @@ class model:
|
|
| 36 |
def __init__(self):
|
| 37 |
self.model = None
|
| 38 |
self.model_name = None
|
|
|
|
| 39 |
|
| 40 |
def gradio_demo(self, model_name, sequence_input, image):
|
| 41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
if self.model_name != model_name:
|
|
|
|
|
|
|
| 43 |
self.model_name = model_name
|
| 44 |
model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
|
|
|
|
| 45 |
model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
|
| 46 |
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
|
| 47 |
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
|
|
|
|
| 36 |
def __init__(self):
|
| 37 |
self.model = None
|
| 38 |
self.model_name = None
|
| 39 |
+
self.model_path = None
|
| 40 |
|
| 41 |
def gradio_demo(self, model_name, sequence_input, image):
|
| 42 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
if self.model_name != model_name:
|
| 44 |
+
if self.model_path is not None:
|
| 45 |
+
os.remove(self.model_path)
|
| 46 |
self.model_name = model_name
|
| 47 |
model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
|
| 48 |
+
self.model_path = model_ckpt_path
|
| 49 |
model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
|
| 50 |
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
|
| 51 |
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
|