""" Klarity HF Space - Core Processing Module Extracted and adapted from klarity.py for headless (server) use. Lite mode only, CPU device. """ import os import sys import shutil import subprocess import logging from pathlib import Path log = logging.getLogger(__name__) import torch import cv2 import numpy as np IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'} VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv', '.m4v', '.mpeg', '.mpg', '.3gp', '.ts', '.mts', '.m2ts', '.ogv'} NAFNET_CONFIGS_LITE = { 'deblur': { 'width': 32, 'middle_blk_num': 1, 'enc_blk_nums': [1, 1, 1, 28], 'dec_blk_nums': [1, 1, 1, 1], }, 'denoise': { 'width': 32, 'middle_blk_num': 12, 'enc_blk_nums': [2, 2, 4, 8], 'dec_blk_nums': [2, 2, 2, 2], }, } class ModelManager: """Loads and caches all lite AI models on CPU.""" def __init__(self, models_dir: str): self.models_dir = models_dir self.device = torch.device('cpu') self._denoise = None self._deblur = None self._upscale = None self._framegen = None # --- public lazy loaders --- def load_denoise(self): if self._denoise is not None: return self._denoise from nafnet_arch import NAFNet cfg = NAFNET_CONFIGS_LITE['denoise'] model = NAFNet(img_channel=3, **cfg) self._load_nafnet_weights(model, os.path.join(self.models_dir, 'denoise-lite.pth')) self._denoise = model.to(self.device).eval() return self._denoise def load_deblur(self): if self._deblur is not None: return self._deblur from nafnet_arch import NAFNetLocal cfg = NAFNET_CONFIGS_LITE['deblur'] model = NAFNetLocal(img_channel=3, **cfg) self._load_nafnet_weights(model, os.path.join(self.models_dir, 'deblur-lite.pth')) self._deblur = model.to(self.device).eval() return self._deblur def load_upscale(self): if self._upscale is not None: return self._upscale from sr_arch import SRVGGNetCompact model = SRVGGNetCompact( num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu', ) path = os.path.join(self.models_dir, 'upscale-lite.pth') ckpt = torch.load(path, map_location='cpu', weights_only=False) # upscale checkpoints use 'params_ema' or 'params' or raw state_dict sd = ckpt.get('params_ema', ckpt.get('params', ckpt)) sd = self._strip_state_dict({'params': sd}) model.load_state_dict(sd) self._upscale = model.to(self.device).eval() return self._upscale def load_framegen(self): if self._framegen is not None: return self._framegen from rife_arch import RIFE model = RIFE(mode='lite') model.load_model(self.models_dir, mode='lite') model.eval() model.device() self._framegen = model return self._framegen # --- helpers --- @staticmethod def _strip_state_dict(ckpt): sd = ckpt.get('params', ckpt.get('state_dict', ckpt)) for k in list(sd.keys()): if k.startswith('module.'): sd[k[7:]] = sd.pop(k) return sd def _load_nafnet_weights(self, model, path): ckpt = torch.load(path, map_location='cpu', weights_only=False) model.load_state_dict(self._strip_state_dict(ckpt)) # ------------------------------------------------------------------ # # Low-level tensor / image helpers # # ------------------------------------------------------------------ # def img2tensor(img, device): img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device) def tensor2img(tensor): arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() arr = np.clip(arr * 255, 0, 255).astype(np.uint8) return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) def pad_image(img, modulo=32): h, w = img.shape[2], img.shape[3] new_h = ((h - 1) // modulo + 1) * modulo new_w = ((w - 1) // modulo + 1) * modulo if new_h > h or new_w > w: img = torch.nn.functional.pad(img, (0, new_w - w, 0, new_h - h)) return img, (h, w) def run_nafnet(model, img_tensor): with torch.no_grad(): padded, (h, w) = pad_image(img_tensor) out = model(padded) return out[:, :, :h, :w] def run_upscale(model, img_tensor): with torch.no_grad(): padded, (h, w) = pad_image(img_tensor, modulo=4) out = model(padded) return out[:, :, : h * 4, : w * 4] # ------------------------------------------------------------------ # # Image processing functions # # ------------------------------------------------------------------ # def img_denoise(img, mm, cb=None): if cb: cb("Denoising...") m = mm.load_denoise() return tensor2img(run_nafnet(m, img2tensor(img, mm.device))) def img_deblur(img, mm, cb=None): if cb: cb("Deblurring...") m = mm.load_deblur() return tensor2img(run_nafnet(m, img2tensor(img, mm.device))) def img_upscale(img, mm, factor=4, cb=None): if cb: cb(f"Upscaling x{factor}...") m = mm.load_upscale() out = tensor2img(run_upscale(m, img2tensor(img, mm.device))) if factor == 2: h, w = out.shape[:2] out = cv2.resize(out, (w // 2, h // 2), interpolation=cv2.INTER_LANCZOS4) return out def img_clean(img, mm, cb=None): img = img_denoise(img, mm, cb) img = img_deblur(img, mm, cb) return img def img_full(img, mm, factor=4, cb=None): img = img_denoise(img, mm, cb) img = img_deblur(img, mm, cb) img = img_upscale(img, mm, factor, cb) return img IMAGE_FUNCS = { 'denoise': lambda img, mm, cb, f: img_denoise(img, mm, cb), 'deblur': lambda img, mm, cb, f: img_deblur(img, mm, cb), 'upscale': lambda img, mm, cb, f: img_upscale(img, mm, f, cb), 'clean': lambda img, mm, cb, f: img_clean(img, mm, cb), 'full': lambda img, mm, cb, f: img_full(img, mm, f, cb), } # ------------------------------------------------------------------ # # Video helpers # # ------------------------------------------------------------------ # def _run_ffmpeg(cmd, label="ffmpeg"): """Run an ffmpeg command and return (returncode, stderr_text). Logs stderr on failure and returns details so callers can raise meaningful errors instead of a bare CalledProcessError. """ log.info("Running: %s", " ".join(cmd)) r = subprocess.run(cmd, capture_output=True, text=True, timeout=600) if r.returncode != 0: stderr_snip = (r.stderr or "").strip().splitlines()[-3:] # last 3 lines log.error("%s failed (rc=%d):\n%s", label, r.returncode, r.stderr or "") raise RuntimeError( f"{label} failed (exit code {r.returncode}). " f"Last ffmpeg output:\n" + "\n".join(stderr_snip) ) return r def ensure_ffmpeg(): if shutil.which('ffmpeg') is None: raise RuntimeError("ffmpeg is not installed on this Space.") def video_info(path): cap = cv2.VideoCapture(path) if not cap.isOpened(): cap.release() raise RuntimeError(f"Cannot open video: {path}") fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() return fps, count, w, h def extract_frames(video_path, out_dir): os.makedirs(out_dir, exist_ok=True) cmd = ['ffmpeg', '-y', '-i', video_path, '-vsync', '0', os.path.join(out_dir, '%08d.png')] _run_ffmpeg(cmd, label="Frame extraction") frames = sorted(f for f in os.listdir(out_dir) if f.endswith('.png')) if not frames: raise RuntimeError( f"Frame extraction produced 0 frames. The video may be empty or unsupported." ) log.info("Extracted %d frames", len(frames)) return frames def extract_audio(video_path, audio_path): cmd = ['ffmpeg', '-y', '-i', video_path, '-vn', '-acodec', 'copy', audio_path] try: r = subprocess.run(cmd, capture_output=True, text=True, timeout=120) if r.returncode == 0 and os.path.isfile(audio_path): log.info("Audio extracted OK") return True except Exception as e: log.warning("Audio extraction failed (non-fatal): %s", e) # Audio extraction is non-fatal — video processing continues without audio return False def frames_to_video(frames_dir, out_path, fps, audio_path=None): tmp = out_path + '_temp.mp4' cmd = ['ffmpeg', '-y', '-framerate', str(fps), '-i', os.path.join(frames_dir, '%08d.png'), '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18', tmp] _run_ffmpeg(cmd, label="Video compilation") if audio_path and os.path.isfile(audio_path): cmd2 = ['ffmpeg', '-y', '-i', tmp, '-i', audio_path, '-c:v', 'copy', '-c:a', 'aac', '-map', '0:v:0', '-map', '1:a:0?', out_path] r2 = subprocess.run(cmd2, capture_output=True, text=True, timeout=300) if r2.returncode == 0 and os.path.isfile(out_path): os.remove(tmp) return log.warning("Audio merge failed (rc=%d), using video without audio", r2.returncode) # Fallback: use video-only if os.path.exists(tmp): if os.path.exists(out_path): os.remove(out_path) os.rename(tmp, out_path) def process_video_frames(src_dir, dst_dir, frames, label, func, cb=None, cancel_event=None): os.makedirs(dst_dir, exist_ok=True) total = len(frames) for i, fname in enumerate(frames): if cancel_event and cancel_event.is_set(): raise RuntimeError("Processing cancelled") if cb: cb(f"{label} — frame {i + 1}/{total}") img = cv2.imread(os.path.join(src_dir, fname)) if img is None: raise RuntimeError( f"Failed to read frame {fname} from {src_dir}. " f"The frame file may be corrupted." ) img = func(img) cv2.imwrite(os.path.join(dst_dir, fname), img) # ------------------------------------------------------------------ # # RIFE frame generation # # ------------------------------------------------------------------ # def pad_for_rife(img, scale=1.0): div = max(64, int(64 / scale)) h, w = img.shape[:2] nh = ((h - 1) // div + 1) * div nw = ((w - 1) // div + 1) * div if nh > h or nw > w: img = np.pad(img, ((0, nh - h), (0, nw - w), (0, 0)), mode='edge') return img, (h, w) def generate_frames(src_dir, dst_dir, multi, mm, cb=None, cancel_event=None): model = mm.load_framegen() frames = sorted(f for f in os.listdir(src_dir) if f.endswith('.png')) if len(frames) < 2: raise ValueError("Need at least 2 frames for interpolation.") os.makedirs(dst_dir, exist_ok=True) idx = 0 for i in range(len(frames) - 1): if cancel_event and cancel_event.is_set(): raise RuntimeError("Processing cancelled") if cb: cb(f"Interpolating — pair {i + 1}/{len(frames) - 1}") img0 = cv2.imread(os.path.join(src_dir, frames[i])) img1 = cv2.imread(os.path.join(src_dir, frames[i + 1])) img0, (oh, ow) = pad_for_rife(img0) img1, _ = pad_for_rife(img1) t0 = img2tensor(img0, mm.device) t1 = img2tensor(img1, mm.device) cv2.imwrite(os.path.join(dst_dir, f'{idx:08d}.png'), img0[:oh, :ow]) idx += 1 for j in range(multi - 1): ts = (j + 1) / multi with torch.no_grad(): mid = model.inference(t0, t1, ts, 1.0) cv2.imwrite(os.path.join(dst_dir, f'{idx:08d}.png'), tensor2img(mid)[:oh, :ow]) idx += 1 last = cv2.imread(os.path.join(src_dir, frames[-1])) cv2.imwrite(os.path.join(dst_dir, f'{idx:08d}.png'), last) return idx + 1 def blend_frames_for_fps(frames_dir, target_fps, generated_fps, cb=None): """Blend generated frames down to a lower target FPS using ffmpeg minterpolate. Matches the desktop Klarity behaviour: when the user requests a target FPS that is below the generated (max) FPS, frames are blended to produce a smooth video at the requested rate. """ if generated_fps <= 0: return frames_dir ratio = target_fps / generated_fps if ratio >= 0.99: # No meaningful difference, skip blending return frames_dir blended_dir = frames_dir + '_blended' os.makedirs(blended_dir, exist_ok=True) if cb: cb(f"Blending frames to {target_fps:.1f} FPS...") cmd = [ 'ffmpeg', '-y', '-framerate', str(generated_fps), '-i', os.path.join(frames_dir, '%08d.png'), '-vf', f'minterpolate=fps={target_fps:.2f}:mi_mode=blend', '-vsync', '0', os.path.join(blended_dir, '%08d.png'), ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode == 0 and os.path.exists(blended_dir): blended_frames = sorted(f for f in os.listdir(blended_dir) if f.endswith('.png')) if len(blended_frames) > 0: return blended_dir # Fallback: if blending failed, return original log.warning("Frame blending failed, using generated frames as-is") if os.path.exists(blended_dir): shutil.rmtree(blended_dir) return frames_dir def clamp_fps(fps, orig_fps, multi): """Clamp user-provided FPS to valid range [min, max] matching desktop Klarity. - fps is None -> use max_fps (default) - fps < min_fps -> warning + use max_fps - fps > max_fps -> warning + use max_fps - otherwise -> use fps as-is Returns the clamped FPS value. """ min_fps = orig_fps max_fps = orig_fps * multi if fps is None: return max_fps if fps < min_fps: log.warning( "Target FPS %.2f below minimum (%.2f). Using max: %.2f", fps, min_fps, max_fps, ) return max_fps if fps > max_fps: log.warning( "Target FPS %.2f exceeds maximum (%.2f). Using max: %.2f", fps, max_fps, max_fps, ) return max_fps return fps # ------------------------------------------------------------------ # # High-level process dispatcher # # ------------------------------------------------------------------ # def get_supported_modes(path): ext = Path(path).suffix.lower() if ext in IMAGE_EXTENSIONS: return ['denoise', 'deblur', 'upscale', 'clean', 'full'] if ext in VIDEO_EXTENSIONS: return ['denoise', 'deblur', 'upscale', 'clean', 'full', 'frame-gen', 'clean-frame-gen', 'full-frame-gen'] return [] MODE_SETTINGS = { 'denoise': {'upscale': False, 'multi': False, 'fps': False}, 'deblur': {'upscale': False, 'multi': False, 'fps': False}, 'upscale': {'upscale': True, 'multi': False, 'fps': False}, 'clean': {'upscale': False, 'multi': False, 'fps': False}, 'full': {'upscale': True, 'multi': False, 'fps': False}, 'frame-gen': {'upscale': False, 'multi': True, 'fps': True}, 'clean-frame-gen': {'upscale': False, 'multi': True, 'fps': True}, 'full-frame-gen': {'upscale': True, 'multi': True, 'fps': True}, } MODE_SUFFIXES = { 'denoise': '_denoised', 'deblur': '_deblurred', 'upscale': '_upscaled', 'clean': '_cleaned', 'full': '_enhanced', 'frame-gen': '_generated', 'clean-frame-gen': '_clean_generated', 'full-frame-gen': '_full_enhanced', } def process_file(input_path, mode, mm, out_dir, *, upscale_factor=4, multi=2, fps=None, cb=None, cancel_event=None): ext = Path(input_path).suffix.lower() # If extension not recognized, try MIME-based detection if ext not in IMAGE_EXTENSIONS and ext not in VIDEO_EXTENSIONS: import mimetypes mime, _ = mimetypes.guess_type(input_path) log.warning("Unrecognized extension '%s', MIME: %s", ext, mime) if mime: if mime.startswith('image/'): return _proc_image(input_path, mode, mm, out_dir, upscale_factor, cb, cancel_event) if mime.startswith('video/'): return _proc_video(input_path, mode, mm, out_dir, upscale_factor, multi, fps, cb, cancel_event) raise ValueError(f"Unsupported format: {ext} (MIME: {mime})") if ext in IMAGE_EXTENSIONS: return _proc_image(input_path, mode, mm, out_dir, upscale_factor, cb, cancel_event) return _proc_video(input_path, mode, mm, out_dir, upscale_factor, multi, fps, cb, cancel_event) def _proc_image(path, mode, mm, out_dir, uf, cb, cancel_event=None): if cb: cb("Reading image…") img = cv2.imread(path) if img is None: raise ValueError(f"Cannot read image: {path}") if cb: cb(f"Processing ({mode})…") result = IMAGE_FUNCS[mode](img, mm, cb, uf) name = Path(path).stem + MODE_SUFFIXES.get(mode, '_processed') + Path(path).suffix out = os.path.join(out_dir, name) cv2.imwrite(out, result) return {'type': 'image', 'before': path, 'after': out} def _proc_video(path, mode, mm, out_dir, uf, multi, fps, cb, cancel_event=None): log.info("_proc_video: path=%s mode=%s uf=%s multi=%s fps=%s", os.path.basename(path), mode, uf, multi, fps) ensure_ffmpeg() orig_fps, frame_count, w, h = video_info(path) log.info("Video info: %.2f FPS, %d frames, %dx%d", orig_fps, frame_count, w, h) tmp = os.path.join(out_dir, 'tmp') if os.path.exists(tmp): shutil.rmtree(tmp) os.makedirs(tmp) orig_exc = None try: frames_dir = os.path.join(tmp, 'frames') audio_path = os.path.join(tmp, 'audio.aac') log.info("Step 1: Extracting frames from %s", os.path.basename(path)) if cancel_event and cancel_event.is_set(): raise RuntimeError("Processing cancelled") if cb: cb("Extracting frames…") frames = extract_frames(path, frames_dir) log.info("Step 2: Extracting audio") extract_audio(path, audio_path) if cancel_event and cancel_event.is_set(): raise RuntimeError("Processing cancelled") fg_modes = {'frame-gen', 'clean-frame-gen', 'full-frame-gen'} if mode in fg_modes: cur = frames_dir if mode in ('clean-frame-gen', 'full-frame-gen'): log.info("Step 3a: Denoising %d frames", len(frames)) d = os.path.join(tmp, 'denoised') process_video_frames(cur, d, frames, "Denoising", lambda img: img_denoise(img, mm), cancel_event=cancel_event) cur = d log.info("Step 3b: Deblurring %d frames", len(frames)) d = os.path.join(tmp, 'cleaned') process_video_frames(cur, d, frames, "Deblurring", lambda img: img_deblur(img, mm), cancel_event=cancel_event) cur = d if mode == 'full-frame-gen': log.info("Step 3c: Upscaling x%d, %d frames", uf, len(frames)) d = os.path.join(tmp, 'upscaled') process_video_frames(cur, d, frames, f"Upscaling x{uf}", lambda img: img_upscale(img, mm, uf), cancel_event=cancel_event) cur = d frames = sorted(f for f in os.listdir(cur) if f.endswith('.png')) max_fps = orig_fps * multi target_fps = clamp_fps(fps, orig_fps, multi) log.info("Step 4: Frame generation (multi=%d, target_fps=%.1f)", multi, target_fps) if cb: cb("Generating frames…") gen = os.path.join(tmp, 'generated') generate_frames(cur, gen, multi, mm, cancel_event=cancel_event) final_frames = gen if target_fps < max_fps: log.info("Step 5: Blending to %.1f FPS", target_fps) final_frames = blend_frames_for_fps(gen, target_fps, max_fps, cb) log.info("Step 6: Compiling video at %.1f FPS", target_fps) if cb: cb("Compiling video…") name = Path(path).stem + MODE_SUFFIXES.get(mode, '_processed') + '.mp4' out = os.path.join(out_dir, name) frames_to_video(final_frames, out, target_fps, audio_path if os.path.exists(audio_path) else None) else: steps = [] if mode in ('denoise', 'clean', 'full'): steps.append(("Denoising", lambda img, c=None: img_denoise(img, mm, c))) if mode in ('deblur', 'clean', 'full'): steps.append(("Deblurring", lambda img, c=None: img_deblur(img, mm, c))) if mode in ('upscale', 'full'): steps.append((f"Upscaling x{uf}", lambda img, c=None: img_upscale(img, mm, uf, c))) cur = frames_dir for step_i, (label, fn) in enumerate(steps): d = os.path.join(tmp, f'step_{step_i}') log.info("Step 3.%d: %s (%d frames)", step_i, label, len(frames)) process_video_frames(cur, d, frames, label, fn, cancel_event=cancel_event) cur = d log.info("Step 4: Compiling video at %.2f FPS", orig_fps) if cb: cb("Compiling video…") name = Path(path).stem + MODE_SUFFIXES.get(mode, '_processed') + '.mp4' out = os.path.join(out_dir, name) frames_to_video(cur, out, orig_fps, audio_path if os.path.exists(audio_path) else None) log.info("Video processing complete: %s", out) return {'type': 'video', 'before': path, 'after': out} except Exception as e: orig_exc = e raise finally: # Cleanup tmp — but never let cleanup errors mask the real error try: if os.path.exists(tmp): shutil.rmtree(tmp) except Exception as cleanup_err: log.error("Cleanup failed (non-fatal): %s", cleanup_err) if orig_exc is None: # Only re-raise cleanup error if there was no original error raise def is_image(path): return Path(path).suffix.lower() in IMAGE_EXTENSIONS def is_video(path): return Path(path).suffix.lower() in VIDEO_EXTENSIONS