# Page header captured along with the file (not part of the module): "Spaces: Paused"
import gc
import os
import traceback
from pathlib import Path

import devicetorch
import numpy as np
import torch
from huggingface_hub import snapshot_download
from torchvision.transforms.functional import to_tensor, to_pil_image

from .message_manager import MessageManager
from .RIFE.RIFE_HDv3 import Model as RIFEBaseModel

# Anchor the weights directory to this module's own location so that it does
# not depend on the current working directory of the process.
_MODULE_DIR = Path(os.path.abspath(__file__)).parent
MODEL_RIFE_PATH = _MODULE_DIR / "model_rife"

# Name of the weight file expected inside MODEL_RIFE_PATH.
RIFE_MODEL_FILENAME = "flownet.pkl"


| class RIFEHandler: | |
| def __init__(self, message_manager: MessageManager = None): | |
| self.message_manager = message_manager if message_manager else MessageManager() | |
| self.model_dir = Path(MODEL_RIFE_PATH) # Path() constructor handles Path objects correctly | |
| self.model_file_path = self.model_dir / RIFE_MODEL_FILENAME | |
| self.rife_model = None | |
| def _log(self, message, level="INFO"): | |
| # Helper for logging using the MessageManager | |
| if level.upper() == "ERROR": | |
| self.message_manager.add_error(f"RIFEHandler: {message}") | |
| elif level.upper() == "WARNING": | |
| self.message_manager.add_warning(f"RIFEHandler: {message}") | |
| else: | |
| self.message_manager.add_message(f"RIFEHandler: {message}") | |
| def _ensure_model_downloaded_and_loaded(self) -> bool: | |
| if self.rife_model is not None: | |
| self._log("RIFE model already loaded.") | |
| return True | |
| # self.model_dir is now an absolute path | |
| if not self.model_dir.exists(): | |
| os.makedirs(self.model_dir, exist_ok=True) | |
| self._log(f"Created RIFE model directory: {self.model_dir}") | |
| # self.model_file_path is now an absolute path | |
| if not self.model_file_path.exists(): | |
| self._log("RIFE model weights not found. Downloading...") | |
| try: | |
| snapshot_download( | |
| repo_id="AlexWortega/RIFE", | |
| allow_patterns=["*.pkl", "*.pth"], | |
| local_dir=self.model_dir, # Pass the absolute path | |
| local_dir_use_symlinks=False | |
| ) | |
| if self.model_file_path.exists(): | |
| self._log("RIFE model weights downloaded successfully.") | |
| else: | |
| self._log(f"RIFE model download completed, but {RIFE_MODEL_FILENAME} not found in {self.model_dir}. Check allow_patterns and repo structure.", "ERROR") | |
| return False | |
| except Exception as e: | |
| self._log(f"Failed to download RIFE model weights: {e}", "ERROR") | |
| return False | |
| if not self.model_file_path.exists(): | |
| self._log(f"RIFE model file {self.model_file_path} does not exist. Cannot load model.", "ERROR") | |
| return False | |
| try: | |
| self._log(f"Loading RIFE model from {self.model_dir}...") # self.model_dir is absolute | |
| current_device_str = devicetorch.get(torch) | |
| self.rife_model = RIFEBaseModel(local_rank=-1) | |
| self.rife_model.load_model(str(self.model_dir), -1) # str(self.model_dir) is absolute | |
| self.rife_model.eval() | |
| self._log(f"RIFE model loaded successfully to its determined device.") | |
| return True | |
| except Exception as e: | |
| self._log(f"Failed to load RIFE model: {e}", "ERROR") | |
| import traceback | |
| self._log(f"Traceback: {traceback.format_exc()}", "ERROR") | |
| self.rife_model = None | |
| return False | |
| def unload_model(self): | |
| if self.rife_model is not None: | |
| self._log("Unloading RIFE model...") | |
| del self.rife_model | |
| self.rife_model = None | |
| devicetorch.empty_cache(torch) | |
| gc.collect() | |
| self._log("RIFE model unloaded and memory cleared.") | |
| else: | |
| self._log("RIFE model not loaded, no need to unload.") | |
| def interpolate_between_frames(self, frame1_np: np.ndarray, frame2_np: np.ndarray) -> np.ndarray | None: | |
| if self.rife_model is None: | |
| self._log("RIFE model not loaded. Call _ensure_model_downloaded_and_loaded() before interpolation.", "ERROR") | |
| return None | |
| try: | |
| img0_tensor = to_tensor(frame1_np).unsqueeze(0) | |
| img1_tensor = to_tensor(frame2_np).unsqueeze(0) | |
| img0 = devicetorch.to(torch, img0_tensor) | |
| img1 = devicetorch.to(torch, img1_tensor) | |
| required_multiple = 32 | |
| h_orig, w_orig = img0.shape[2], img0.shape[3] | |
| pad_h = (required_multiple - h_orig % required_multiple) % required_multiple | |
| pad_w = (required_multiple - w_orig % required_multiple) % required_multiple | |
| if pad_h > 0 or pad_w > 0: | |
| img0 = torch.nn.functional.pad(img0, (0, pad_w, 0, pad_h), mode='replicate') | |
| img1 = torch.nn.functional.pad(img1, (0, pad_w, 0, pad_h), mode='replicate') | |
| with torch.no_grad(): | |
| middle_frame_tensor = self.rife_model.inference(img0, img1, scale=1.0) | |
| if pad_h > 0 or pad_w > 0: | |
| middle_frame_tensor = middle_frame_tensor[:, :, :h_orig, :w_orig] | |
| middle_frame_pil = to_pil_image(middle_frame_tensor.squeeze(0).cpu()) | |
| return np.array(middle_frame_pil) | |
| except Exception as e: | |
| self._log(f"Error during RIFE frame interpolation: {e}", "ERROR") | |
| import traceback | |
| self._log(f"Traceback: {traceback.format_exc()}", "ERROR") | |
| if "out of memory" in str(e).lower(): | |
| devicetorch.empty_cache(torch) | |
| return None |