marduk-ra commited on
Commit
e9379ec
·
verified ·
1 Parent(s): 5e01da9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +2 -13
  2. requirements.txt +2 -2
app.py CHANGED
@@ -446,7 +446,6 @@ class ModelState:
446
  """Global state for loaded model."""
447
 
448
  model: Optional[FeatureFusionModel] = None
449
- device: str = "cpu"
450
 
451
 
452
  state = ModelState()
@@ -457,15 +456,13 @@ def load_model() -> str:
457
  if state.model is not None:
458
  return "Model already loaded"
459
 
460
- device = "cuda" if torch.cuda.is_available() else "cpu"
461
-
462
  try:
463
  checkpoint_path = hf_hub_download(
464
  repo_id="marduk-ra/MFIR",
465
  filename="temporal_fusion_model.pth",
466
  )
467
 
468
- ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
469
 
470
  config_dict = ckpt["config"]
471
  if isinstance(config_dict, dict):
@@ -482,13 +479,11 @@ def load_model() -> str:
482
  config = FeatureFusionConfig.from_dict(config_dict)
483
  model = FeatureFusionModel(config)
484
  model.load_state_dict_with_compatibility(ckpt["state_dict"])
485
- model.to(device)
486
  model.eval()
487
 
488
  state.model = model
489
- state.device = device
490
 
491
- return f"Model loaded on {device}"
492
 
493
  except Exception as e:
494
  return f"Error loading model: {e}"
@@ -609,8 +604,6 @@ def process_images(
609
  if frames is None:
610
  return None, "Failed to preprocess images"
611
 
612
- frames = frames.to(state.device)
613
-
614
  ref_idx = min(ref_frame, frames.shape[1] - 1)
615
 
616
  with torch.no_grad():
@@ -619,10 +612,6 @@ def process_images(
619
  output = result["output"]
620
  output_pil = postprocess_output(output, preprocess_info)
621
 
622
- del result, output, frames
623
- if state.device == "cuda":
624
- torch.cuda.empty_cache()
625
-
626
  return output_pil, f"Processed {len(images)} frames (reference: frame {ref_idx})"
627
 
628
  except Exception as e:
 
446
  """Global state for loaded model."""
447
 
448
  model: Optional[FeatureFusionModel] = None
 
449
 
450
 
451
  state = ModelState()
 
456
  if state.model is not None:
457
  return "Model already loaded"
458
 
 
 
459
  try:
460
  checkpoint_path = hf_hub_download(
461
  repo_id="marduk-ra/MFIR",
462
  filename="temporal_fusion_model.pth",
463
  )
464
 
465
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
466
 
467
  config_dict = ckpt["config"]
468
  if isinstance(config_dict, dict):
 
479
  config = FeatureFusionConfig.from_dict(config_dict)
480
  model = FeatureFusionModel(config)
481
  model.load_state_dict_with_compatibility(ckpt["state_dict"])
 
482
  model.eval()
483
 
484
  state.model = model
 
485
 
486
+ return "Model loaded"
487
 
488
  except Exception as e:
489
  return f"Error loading model: {e}"
 
604
  if frames is None:
605
  return None, "Failed to preprocess images"
606
 
 
 
607
  ref_idx = min(ref_frame, frames.shape[1] - 1)
608
 
609
  with torch.no_grad():
 
612
  output = result["output"]
613
  output_pil = postprocess_output(output, preprocess_info)
614
 
 
 
 
 
615
  return output_pil, f"Processed {len(images)} frames (reference: frame {ref_idx})"
616
 
617
  except Exception as e:
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- torch>=2.1.0
2
- torchvision>=0.16.0
3
  huggingface_hub
4
  numpy
5
  Pillow
 
1
+ torch
2
+ torchvision
3
  huggingface_hub
4
  numpy
5
  Pillow