Upload 2 files
Browse files
- app.py +2 -13
- requirements.txt +2 -2
app.py
CHANGED
|
@@ -446,7 +446,6 @@ class ModelState:
|
|
| 446 |
"""Global state for loaded model."""
|
| 447 |
|
| 448 |
model: Optional[FeatureFusionModel] = None
|
| 449 |
-
device: str = "cpu"
|
| 450 |
|
| 451 |
|
| 452 |
state = ModelState()
|
|
@@ -457,15 +456,13 @@ def load_model() -> str:
|
|
| 457 |
if state.model is not None:
|
| 458 |
return "Model already loaded"
|
| 459 |
|
| 460 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 461 |
-
|
| 462 |
try:
|
| 463 |
checkpoint_path = hf_hub_download(
|
| 464 |
repo_id="marduk-ra/MFIR",
|
| 465 |
filename="temporal_fusion_model.pth",
|
| 466 |
)
|
| 467 |
|
| 468 |
-
ckpt = torch.load(checkpoint_path, map_location=device)
|
| 469 |
|
| 470 |
config_dict = ckpt["config"]
|
| 471 |
if isinstance(config_dict, dict):
|
|
@@ -482,13 +479,11 @@ def load_model() -> str:
|
|
| 482 |
config = FeatureFusionConfig.from_dict(config_dict)
|
| 483 |
model = FeatureFusionModel(config)
|
| 484 |
model.load_state_dict_with_compatibility(ckpt["state_dict"])
|
| 485 |
-
model.to(device)
|
| 486 |
model.eval()
|
| 487 |
|
| 488 |
state.model = model
|
| 489 |
-
state.device = device
|
| 490 |
|
| 491 |
-
return
|
| 492 |
|
| 493 |
except Exception as e:
|
| 494 |
return f"Error loading model: {e}"
|
|
@@ -609,8 +604,6 @@ def process_images(
|
|
| 609 |
if frames is None:
|
| 610 |
return None, "Failed to preprocess images"
|
| 611 |
|
| 612 |
-
frames = frames.to(state.device)
|
| 613 |
-
|
| 614 |
ref_idx = min(ref_frame, frames.shape[1] - 1)
|
| 615 |
|
| 616 |
with torch.no_grad():
|
|
@@ -619,10 +612,6 @@ def process_images(
|
|
| 619 |
output = result["output"]
|
| 620 |
output_pil = postprocess_output(output, preprocess_info)
|
| 621 |
|
| 622 |
-
del result, output, frames
|
| 623 |
-
if state.device == "cuda":
|
| 624 |
-
torch.cuda.empty_cache()
|
| 625 |
-
|
| 626 |
return output_pil, f"Processed {len(images)} frames (reference: frame {ref_idx})"
|
| 627 |
|
| 628 |
except Exception as e:
|
|
|
|
| 446 |
"""Global state for loaded model."""
|
| 447 |
|
| 448 |
model: Optional[FeatureFusionModel] = None
|
|
|
|
| 449 |
|
| 450 |
|
| 451 |
state = ModelState()
|
|
|
|
| 456 |
if state.model is not None:
|
| 457 |
return "Model already loaded"
|
| 458 |
|
|
|
|
|
|
|
| 459 |
try:
|
| 460 |
checkpoint_path = hf_hub_download(
|
| 461 |
repo_id="marduk-ra/MFIR",
|
| 462 |
filename="temporal_fusion_model.pth",
|
| 463 |
)
|
| 464 |
|
| 465 |
+
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
| 466 |
|
| 467 |
config_dict = ckpt["config"]
|
| 468 |
if isinstance(config_dict, dict):
|
|
|
|
| 479 |
config = FeatureFusionConfig.from_dict(config_dict)
|
| 480 |
model = FeatureFusionModel(config)
|
| 481 |
model.load_state_dict_with_compatibility(ckpt["state_dict"])
|
|
|
|
| 482 |
model.eval()
|
| 483 |
|
| 484 |
state.model = model
|
|
|
|
| 485 |
|
| 486 |
+
return "Model loaded"
|
| 487 |
|
| 488 |
except Exception as e:
|
| 489 |
return f"Error loading model: {e}"
|
|
|
|
| 604 |
if frames is None:
|
| 605 |
return None, "Failed to preprocess images"
|
| 606 |
|
|
|
|
|
|
|
| 607 |
ref_idx = min(ref_frame, frames.shape[1] - 1)
|
| 608 |
|
| 609 |
with torch.no_grad():
|
|
|
|
| 612 |
output = result["output"]
|
| 613 |
output_pil = postprocess_output(output, preprocess_info)
|
| 614 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
return output_pil, f"Processed {len(images)} frames (reference: frame {ref_idx})"
|
| 616 |
|
| 617 |
except Exception as e:
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
torch
|
| 2 |
-
torchvision
|
| 3 |
huggingface_hub
|
| 4 |
numpy
|
| 5 |
Pillow
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
huggingface_hub
|
| 4 |
numpy
|
| 5 |
Pillow
|