# Klarity — src/klarity.py
# NOTE: the original first lines were Hugging Face Hub page residue
# ("Upload folder using huggingface_hub", commit hash), not Python source;
# they are preserved here as a comment so the module parses.
import os
import sys
import argparse
import glob
import shutil
import subprocess
import time
import json
import threading
from pathlib import Path
from datetime import datetime, timedelta
import torch
import cv2
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
JSON_PROGRESS = False
# Optional dependency: model_downloader fetches model weights on demand.
# When it is missing, install no-op shims so the CLI still works with
# locally present model files (auto-download is simply disabled).
try:
    from model_downloader import ensure_models, check_internet_connection, set_model_mode, get_model_mode, get_model_paths_for_mode, MODEL_INFO
except ImportError:
    def ensure_models(*args, **kwargs):
        # Shim: warn that auto-download is unavailable and report success.
        print("Warning: model_downloader module not found. Auto-download disabled.")
        return True
    def check_internet_connection():
        # Shim: optimistically assume connectivity.
        return True
    def set_model_mode(mode):
        # Shim: mode selection is ignored without the downloader module.
        pass
    def get_model_mode():
        # Shim: default to the 'heavy' model set.
        return 'heavy'
    def get_model_paths_for_mode(script_dir, mode=None):
        # Shim: mirror the downloader's expected on-disk model layout.
        return {
            'deblur': os.path.join(script_dir, 'models', f'deblur-{mode or "heavy"}.pth'),
            'denoise': os.path.join(script_dir, 'models', f'denoise-{mode or "heavy"}.pth'),
            'upscale': os.path.join(script_dir, 'models', f'upscale-{mode or "heavy"}.pth'),
            'rife': os.path.join(script_dir, 'models', f'framegen-{mode or "heavy"}.pkl'),
        }
    MODEL_INFO = {}  # no downloadable-model metadata without the module
# Filesystem layout: models and temporary frames live next to this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(SCRIPT_DIR, "models")
TEMP_DIR = os.path.join(SCRIPT_DIR, "tmp")
# Supported input formats (suffixes are compared lower-cased).
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv', '.m4v'}
# Lazily resolved torch device, set by get_device()/select_device().
device = None
device_preference = None  # NOTE(review): not read anywhere in this chunk — confirm before removing
# NAFNet architecture hyper-parameters, keyed by task then model mode.
# 'heavy' builds width-64 networks, 'lite' width-32 (see the mode menu in
# select_model_mode); block counts differ between the deblur and denoise
# variants.
NAFNET_CONFIGS = {
    'deblur': {
        'heavy': {
            'width': 64,
            'middle_blk_num': 1,
            'enc_blk_nums': [1, 1, 1, 28],
            'dec_blk_nums': [1, 1, 1, 1],
        },
        'lite': {
            'width': 32,
            'middle_blk_num': 1,
            'enc_blk_nums': [1, 1, 1, 28],
            'dec_blk_nums': [1, 1, 1, 1],
        },
    },
    'denoise': {
        'heavy': {
            'width': 64,
            'middle_blk_num': 12,
            'enc_blk_nums': [2, 2, 4, 8],
            'dec_blk_nums': [2, 2, 2, 2],
        },
        'lite': {
            'width': 32,
            'middle_blk_num': 12,
            'enc_blk_nums': [2, 2, 4, 8],
            'dec_blk_nums': [2, 2, 2, 2],
        },
    },
}
class ProgressTracker:
    """Tracks batch progress across files and renders it either as a single
    rewriting terminal status line or as JSON lines (when JSON_PROGRESS)."""

    def __init__(self):
        self.total_files = 0           # files in the current batch
        self.current_file_idx = 0      # 1-based index of the file in progress
        self.start_time = None         # batch start timestamp (time.time())
        self.current_file_start = None # current file start timestamp
        self.current_step = ""         # label of the current pipeline step
        self.current_file_name = ""    # basename of the current file
        self.file_times = []           # per-file durations, drives the ETA
        self._last_update = 0          # throttle timestamp for print_status()

    def start_batch(self, total_files):
        """Reset counters for a new batch of *total_files* files."""
        self.total_files = total_files
        self.current_file_idx = 0
        self.start_time = time.time()
        self.file_times = []

    def start_file(self, file_name):
        """Begin tracking the next file in the batch."""
        self.current_file_idx += 1
        self.current_file_name = os.path.basename(file_name)
        self.current_file_start = time.time()
        self.current_step = "Loading"

    def set_step(self, step_name):
        """Set the label describing the current processing step."""
        self.current_step = step_name

    def finish_file(self):
        """Record and return the elapsed seconds for the current file."""
        if self.current_file_start:
            elapsed = time.time() - self.current_file_start
            self.file_times.append(elapsed)
            return elapsed
        return 0

    def get_elapsed_str(self):
        """Formatted time since batch start ('00:00' before start_batch)."""
        if not self.start_time:
            return "00:00"
        elapsed = time.time() - self.start_time
        return self._format_time(elapsed)

    def get_eta_str(self):
        """ETA from the mean per-file time; needs >= 2 finished files."""
        if len(self.file_times) < 2:
            return "calculating..."
        avg_time = sum(self.file_times) / len(self.file_times)
        remaining_files = self.total_files - self.current_file_idx
        eta_seconds = avg_time * remaining_files
        return self._format_time(eta_seconds)

    def _format_time(self, seconds):
        # MM:SS under an hour, HH:MM:SS otherwise.
        # NOTE(review): strftime wraps past 24h; acceptable for typical runs.
        if seconds < 3600:
            return time.strftime("%M:%S", time.gmtime(seconds))
        return time.strftime("%H:%M:%S", time.gmtime(seconds))

    def print_status(self, force=False):
        """Emit the current status, throttled to ~10 Hz unless *force*."""
        global JSON_PROGRESS
        now = time.time()
        if not force and (now - self._last_update) < 0.1:
            return
        self._last_update = now
        elapsed = self.get_elapsed_str()
        eta = self.get_eta_str()
        progress_pct = (self.current_file_idx / self.total_files * 100) if self.total_files > 0 else 0
        display_name = self.current_file_name
        if len(display_name) > 35:
            display_name = display_name[:32] + "..."
        if JSON_PROGRESS:
            # Machine-readable mode: one JSON object per line on stdout.
            json_output = json.dumps({
                'percent': int(progress_pct),
                'step': self.current_step,
                'file': display_name,
                'file_num': self.current_file_idx,
                'total_files': self.total_files,
                'elapsed': elapsed,
                'eta': eta
            })
            print(json_output)
            sys.stdout.flush()
        else:
            # Human mode: rewrite a single padded line in place with '\r'.
            status = f"\r[{self.current_file_idx}/{self.total_files}] ({progress_pct:5.1f}%) | {elapsed} elapsed, ETA: {eta} | {display_name} | {self.current_step}"
            status = status.ljust(120)
            sys.stdout.write(status)
            sys.stdout.flush()

    def print_newline(self):
        """Terminate the in-place status line."""
        sys.stdout.write("\n")
        sys.stdout.flush()
progress = ProgressTracker()  # module-wide tracker shared by the processing helpers
class StepProgressBar:
    """Single-line in-place progress bar for the per-file pipeline steps."""

    def __init__(self, steps, file_name=""):
        self.steps = steps                # total number of steps expected
        self.current_step_idx = 0         # steps completed so far
        self.file_name = file_name        # shown (truncated) next to the bar
        self.bar_width = 30               # characters in the bar itself

    def update(self, step_name):
        """Advance one step and redraw the bar labelled with *step_name*."""
        self.current_step_idx += 1
        done = self.current_step_idx
        pct = done / self.steps * 100
        filled = int(self.bar_width * done / self.steps)
        bar = "█" * filled + "░" * (self.bar_width - filled)
        name = self.file_name
        if len(name) > 25:
            name = name[:22] + "..."
        line = f"\r {name} [{bar}] {done}/{self.steps} ({pct:5.1f}%) - {step_name}"
        sys.stdout.write(line.ljust(100))
        sys.stdout.flush()

    def finish(self):
        """Render a full bar with a 'Done!' label and move to a new line."""
        bar = "█" * self.bar_width
        line = f"\r {self.file_name[:25]:<25} [{bar}] {self.steps}/{self.steps} (100.0%) - Done!"
        sys.stdout.write(line.ljust(100))
        sys.stdout.flush()
        print()
def get_model_paths():
    """Return the model-file paths for the currently selected mode."""
    return get_model_paths_for_mode(SCRIPT_DIR, get_model_mode())
def check_and_download_models():
    """Ensure all model files for the active mode exist, downloading
    (with a user prompt) any that are missing."""
    active_mode = get_model_mode()
    paths = get_model_paths()
    ensure_models(SCRIPT_DIR, paths, auto_download=True, prompt=True, mode=active_mode)
def check_gpu():
    """Probe CUDA device 0.

    Returns (available, device_name, memory_string); the CPU fallback is
    (False, "CPU", "N/A").
    """
    if not torch.cuda.is_available():
        return False, "CPU", "N/A"
    name = torch.cuda.get_device_name(0)
    mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    return True, name, f"{mem_gb:.1f}GB"
def get_device(force_cpu=False, device_type=None):
    """Resolve, cache, and return the module-level torch device.

    *force_cpu* or device_type='cpu' forces CPU; device_type='gpu'
    requests CUDA (falling back to CPU with a warning when unavailable);
    anything else auto-detects. Subsequent calls return the cached device.
    """
    global device
    if device is not None:
        return device
    if force_cpu or device_type == 'cpu':
        device = torch.device('cpu')
        print("Device: CPU")
        return device
    cuda_ok = torch.cuda.is_available()
    if device_type == 'gpu':
        if cuda_ok:
            device = torch.device('cuda')
            print(f"Device: GPU ({torch.cuda.get_device_name(0)})")
        else:
            print("GPU requested but not available - falling back to CPU")
            device = torch.device('cpu')
        return device
    # Auto-detect path.
    if cuda_ok:
        device = torch.device('cuda')
        print(f"Device: GPU ({torch.cuda.get_device_name(0)})")
    else:
        device = torch.device('cpu')
        print("Device: CPU (no GPU available)")
    return device
def select_device():
    """Interactively ask the user to run on CPU or GPU.

    Empty input (or an invalid choice) auto-detects: GPU when CUDA is
    available, otherwise CPU. Sets and returns the module-level device.
    """
    global device, device_preference
    has_gpu, gpu_name, gpu_memory = check_gpu()
    print("\n" + "-"*40)
    print("Select device:")
    if has_gpu:
        print(f" 1. CPU")
        print(f" 2. GPU ({gpu_name}, {gpu_memory})")
    else:
        print(f" 1. CPU (only option available)")
        print(f" 2. GPU (not available)")
    print("\n Press Enter for auto-detect (GPU if available, else CPU)")
    choice = input("> ").strip()
    if choice == '':
        # Auto-detect on empty input.
        if has_gpu:
            device = torch.device('cuda')
            print(f"\nUsing GPU: {gpu_name} ({gpu_memory})")
        else:
            device = torch.device('cpu')
            print("\nUsing CPU (no GPU available)")
    elif choice == '1':
        device = torch.device('cpu')
        print("\nUsing CPU")
    elif choice == '2':
        if has_gpu:
            device = torch.device('cuda')
            print(f"\nUsing GPU: {gpu_name} ({gpu_memory})")
        else:
            # Honour the request as far as possible but never fail.
            print("\nGPU requested but not available - falling back to CPU")
            device = torch.device('cpu')
    else:
        # Unknown input behaves like auto-detect.
        print("\nInvalid choice, auto-detecting...")
        if has_gpu:
            device = torch.device('cuda')
            print(f"Using GPU: {gpu_name} ({gpu_memory})")
        else:
            device = torch.device('cpu')
            print("Using CPU")
    return device
def select_model_mode():
    """Interactively ask the user for the model mode.

    Returns 'heavy' (choice 1 or empty input) or 'lite' (choice 2);
    re-prompts on anything else.
    """
    print("\n" + "="*60)
    print("SELECT MODEL MODE")
    print("="*60)
    print("\n 1. Heavy - Better quality, larger models (default)")
    print(" 2. Lite - Faster processing, smaller models")
    print("")
    print(" Heavy models: NAFNet-width64, Real-ESRGAN-x4plus, RIFE-v4.25")
    print(" Lite models: NAFNet-width32, Real-ESRGAN-general-x4v3, RIFE-v4.17")
    while True:
        choice = input("\nSelect mode (1 or 2): ").strip()
        if choice in ('', '1'):
            return 'heavy'
        if choice == '2':
            return 'lite'
        print(f"Invalid input: '{choice}'. Please enter 1 for Heavy or 2 for Lite.")
# Lazily-loaded model singletons, populated by the load_* helpers below.
deblur_model = None
denoise_model = None
upscale_model = None
framegen_model = None
def load_deblur_model():
    """Load (once) and cache the NAFNet deblur model for the active mode."""
    global deblur_model
    if deblur_model is not None:
        return deblur_model
    mode = get_model_mode()
    model_paths = get_model_paths()
    # nafnet_arch lives inside the models directory, not on sys.path.
    sys.path.insert(0, MODELS_DIR)
    from nafnet_arch import NAFNetLocal
    # Unknown modes fall back to the 'heavy' configuration.
    config = NAFNET_CONFIGS['deblur'].get(mode, NAFNET_CONFIGS['deblur']['heavy'])
    model = NAFNetLocal(
        img_channel=3,
        width=config['width'],
        middle_blk_num=config['middle_blk_num'],
        enc_blk_nums=config['enc_blk_nums'],
        dec_blk_nums=config['dec_blk_nums'],
    )
    deblur_model_path = model_paths['deblur']
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load trusted checkpoint files.
    checkpoint = torch.load(deblur_model_path, map_location='cpu')
    # Checkpoints may nest weights under 'params' or 'state_dict'.
    state_dict = checkpoint.get('params', checkpoint.get('state_dict', checkpoint))
    # Strip DataParallel's 'module.' prefix where present.
    for k in list(state_dict.keys()):
        if k.startswith('module.'):
            state_dict[k[7:]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model = model.to(get_device())
    model.eval()
    deblur_model = model
    print(f"Loaded deblur model: {mode.upper()} (width={config['width']})")
    return model
def load_denoise_model():
    """Load (once) and cache the NAFNet denoise model for the active mode."""
    global denoise_model
    if denoise_model is not None:
        return denoise_model
    mode = get_model_mode()
    model_paths = get_model_paths()
    # nafnet_arch lives inside the models directory, not on sys.path.
    sys.path.insert(0, MODELS_DIR)
    from nafnet_arch import NAFNet
    # Unknown modes fall back to the 'heavy' configuration.
    config = NAFNET_CONFIGS['denoise'].get(mode, NAFNET_CONFIGS['denoise']['heavy'])
    model = NAFNet(
        img_channel=3,
        width=config['width'],
        middle_blk_num=config['middle_blk_num'],
        enc_blk_nums=config['enc_blk_nums'],
        dec_blk_nums=config['dec_blk_nums'],
    )
    denoise_model_path = model_paths['denoise']
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load trusted checkpoint files.
    checkpoint = torch.load(denoise_model_path, map_location='cpu')
    # Checkpoints may nest weights under 'params' or 'state_dict'.
    state_dict = checkpoint.get('params', checkpoint.get('state_dict', checkpoint))
    # Strip DataParallel's 'module.' prefix where present.
    for k in list(state_dict.keys()):
        if k.startswith('module.'):
            state_dict[k[7:]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model = model.to(get_device())
    model.eval()
    denoise_model = model
    print(f"Loaded denoise model: {mode.upper()} (width={config['width']})")
    return model
def load_upscale_model():
    """Load (once) and cache the 4x super-resolution model.

    'heavy' mode uses the RRDBNet (RealESRGAN-x4plus) architecture,
    'lite' the compact SRVGG (RealESRGAN-general-x4v3) variant.
    """
    global upscale_model
    if upscale_model is not None:
        return upscale_model
    mode = get_model_mode()
    model_paths = get_model_paths()
    # sr_arch lives next to this script.
    sys.path.insert(0, SCRIPT_DIR)
    from sr_arch import RRDBNet, SRVGGNetCompact
    if mode == 'heavy':
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4
        )
        model_name = "RealESRGAN-x4plus"
    else:
        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type='prelu'
        )
        model_name = "RealESRGAN-general-x4v3"
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load trusted checkpoint files.
    checkpoint = torch.load(model_paths['upscale'], map_location='cpu')
    # Prefer EMA weights, then 'params'/'state_dict', then the raw dict.
    state_dict = checkpoint.get('params_ema', checkpoint.get('params', checkpoint.get('state_dict', checkpoint)))
    # Strip DataParallel's 'module.' prefix where present.
    for k in list(state_dict.keys()):
        if k.startswith('module.'):
            state_dict[k[7:]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model = model.to(get_device())
    model.eval()
    upscale_model = model
    print(f"Loaded upscale model: {mode.upper()} ({model_name})")
    return model
def load_framegen_model():
    """Load (once) and cache the RIFE frame-interpolation model.

    'heavy' maps to RIFE v4.25, 'lite' to v4.17 (per the printed label).
    """
    global framegen_model
    if framegen_model is not None:
        return framegen_model
    mode = get_model_mode()
    model_paths = get_model_paths()  # NOTE(review): unused here; RIFE loads from MODELS_DIR directly
    # rife_arch lives next to this script.
    sys.path.insert(0, SCRIPT_DIR)
    from rife_arch import RIFE
    model = RIFE(mode=mode)
    model.load_model(MODELS_DIR, mode=mode)
    model.eval()
    model.device()  # moves the model to the active device (RIFE API)
    framegen_model = model
    version = "4.25" if mode == 'heavy' else "4.17"
    print(f"Loaded frame generation model: {mode.upper()} (RIFE v{version})")
    return model
def pad_image(img, modulo=32):
    """Zero-pad an NCHW tensor on the bottom/right so H and W become
    multiples of *modulo*.

    Returns (padded_tensor, (orig_h, orig_w)) so callers can crop back.
    """
    h, w = img.shape[2], img.shape[3]
    # Amount needed to reach the next multiple (0 when already aligned).
    pad_h = (-h) % modulo
    pad_w = (-w) % modulo
    if pad_h or pad_w:
        img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
    return img, (h, w)
def process_nafnet(model, img_tensor):
    """Run a NAFNet model on *img_tensor*, padding H/W to a multiple of 32
    and cropping the output back to the input's spatial size."""
    with torch.no_grad():
        padded, (h, w) = pad_image(img_tensor)
        out = model(padded)
    return out[:, :, :h, :w]
def process_upscale(model, img_tensor):
    """Run the 4x super-resolution model on *img_tensor*.

    Pads to a multiple of 4, then crops the output to exactly 4x the
    original (pre-padding) height and width.
    """
    with torch.no_grad():
        padded, (h, w) = pad_image(img_tensor, modulo=4)
        out = model(padded)
    return out[:, :, :h * 4, :w * 4]
def img2tensor(img):
    """Convert a BGR uint8 HxWxC image to a float32 NCHW tensor in [0, 1]
    on the active device."""
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    arr = rgb.astype(np.float32) / 255.
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    return tensor.to(get_device())
def tensor2img(tensor):
    """Convert a float NCHW tensor in [0, 1] back to a BGR uint8 image."""
    arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
def is_image(path):
    """True when *path* has a supported image extension (case-insensitive)."""
    suffix = Path(path).suffix.lower()
    return suffix in IMAGE_EXTENSIONS
def is_video(path):
    """True when *path* has a supported video extension (case-insensitive)."""
    suffix = Path(path).suffix.lower()
    return suffix in VIDEO_EXTENSIONS
def get_files(path):
    """Return the list of processable files for *path*.

    A file path is returned as a one-element list; a directory is scanned
    (non-recursively) for files with supported image/video extensions;
    anything else yields []. Results are de-duplicated before sorting —
    on case-insensitive filesystems (Windows, default macOS) the
    lower-case and upper-case glob patterns would otherwise match the
    same file twice and it would be processed twice.
    """
    path = Path(path)
    if path.is_file():
        return [str(path)]
    if path.is_dir():
        matches = []
        for ext in IMAGE_EXTENSIONS | VIDEO_EXTENSIONS:
            matches.extend(path.glob(f"*{ext}"))
            matches.extend(path.glob(f"*{ext.upper()}"))
        # set() collapses duplicates from the overlapping glob patterns.
        return sorted({str(f) for f in matches})
    return []
def parse_multiple_paths(input_string):
    """Split a command-line-style string into individual paths.

    Spaces/tabs separate paths unless they appear inside single or double
    quotes; the quote characters themselves are stripped. Empty tokens
    are dropped.
    """
    paths = []
    buf = []          # characters of the token being built
    quote = None      # active quote character, or None when outside quotes
    for ch in input_string:
        if ch in ('"', "'"):
            if quote is None:
                quote = ch          # open a quoted section
            elif ch == quote:
                quote = None        # close the matching quote
            else:
                buf.append(ch)      # the other quote char is literal inside
        elif ch in (' ', '\t') and quote is None:
            token = ''.join(buf).strip()
            if token:
                paths.append(token)
            buf = []
        else:
            buf.append(ch)
    token = ''.join(buf).strip()
    if token:
        paths.append(token)
    return paths
def categorize_path(path_str):
    """Classify one user-supplied path.

    Returns (category, files): category is 'valid', 'not_exist',
    'not_supported', or 'invalid' (blank / punctuation-only input), and
    files is the list of processable files (empty unless 'valid').
    """
    path = Path(path_str)
    if not path.exists():
        cleaned = path_str.strip()
        # Strings with no alphanumeric characters are treated as garbage
        # input rather than as a missing path.
        if not cleaned or not any(c.isalnum() for c in cleaned):
            return ('invalid', [])
        return ('not_exist', [])
    if path.is_file():
        if is_image(str(path)) or is_video(str(path)):
            return ('valid', [str(path)])
        return ('not_supported', [])
    if path.is_dir():
        files = get_files(str(path))
        return ('valid', files) if files else ('not_supported', [])
    # Exists but is neither file nor directory (socket, device, ...).
    return ('not_supported', [])
def categorize_multiple_paths(path_list):
    """Classify every path in *path_list* via categorize_path().

    Returns a dict of per-category lists; 'valid' holds
    (original_path, files) pairs and 'all_valid_files' is the flattened
    list of every processable file found.
    """
    result = {
        'valid': [],
        'not_exist': [],
        'not_supported': [],
        'invalid': [],
        'all_valid_files': [],
    }
    for path_str in path_list:
        category, files = categorize_path(path_str)
        if category == 'valid':
            result['valid'].append((path_str, files))
            result['all_valid_files'].extend(files)
        elif category in ('not_exist', 'not_supported'):
            result[category].append(path_str)
        else:
            result['invalid'].append(path_str)
    return result
def display_path_summary(categorized, max_display=5):
    """Print a human-readable summary of categorized input paths.

    Shows up to *max_display* entries per category followed by a
    '+N more' line, then a final count of files ready to process.
    Prints nothing when there were no input paths at all.

    Refactored: the four near-identical per-category sections now share
    one listing helper; the output is unchanged.
    """

    def _print_listing(header, items, overflow_label, fmt=lambda item: f" {item}"):
        # Shared "show first N, then '... +K more <label>'" rendering.
        print(header)
        for item in items[:max_display]:
            print(fmt(item))
        if len(items) > max_display:
            print(f" ... +{len(items) - max_display} more {overflow_label}")

    valid = categorized['valid']
    not_exist = categorized['not_exist']
    not_supported = categorized['not_supported']
    invalid = categorized['invalid']
    if len(valid) + len(not_exist) + len(not_supported) + len(invalid) == 0:
        return
    print("\n" + "-"*60)
    print("INPUT PATH SUMMARY")
    print("-"*60)
    if valid:
        def _fmt_valid(entry):
            # Single files print as-is; folders show their file count.
            orig_path, files = entry
            if len(files) == 1:
                return f" {orig_path}"
            return f" {orig_path}/ ({len(files)} files)"
        _print_listing(f"\n✓ VALID ({len(valid)}):", valid, "valid", _fmt_valid)
    if not_exist:
        _print_listing(f"\n✗ NOT FOUND ({len(not_exist)}):", not_exist, "not found")
    if not_supported:
        _print_listing(f"\n⚠ NOT SUPPORTED ({len(not_supported)}):", not_supported, "not supported")
    if invalid:
        _print_listing(f"\n? INVALID INPUT ({len(invalid)}):", invalid, "invalid")
    valid_file_count = len(categorized['all_valid_files'])
    if valid_file_count > 0:
        print(f"\n→ {valid_file_count} file(s) ready to process from {len(valid)} valid path(s)")
    else:
        print(f"\n→ No valid files found to process")
def generate_output_path(input_path, mode, output_arg=None):
    """Build the output path for *input_path* processed in *mode*.

    Without *output_arg* the file is written next to the input with a
    mode-specific suffix. If *output_arg* looks like a folder (trailing
    slash, an existing directory, or an extension-less non-existent
    path), the suffixed file is placed inside it; otherwise *output_arg*
    is used verbatim.
    """
    input_path = Path(input_path)
    suffixed_name = f"{input_path.stem}{get_mode_suffix(mode)}{input_path.suffix}"
    if not output_arg:
        return str(input_path.parent / suffixed_name)
    output_path = Path(output_arg)
    looks_like_folder = (
        str(output_path).endswith('/')
        or str(output_path).endswith('\\')
        or output_path.is_dir()
        or (output_path.suffix == '' and not output_path.exists())
    )
    if looks_like_folder:
        return str(output_path / suffixed_name)
    return str(output_arg)
def get_mode_suffix(mode):
    """Filename suffix appended to outputs of the given processing mode
    ('_processed' for unknown modes)."""
    return {
        'denoise': '_denoised',
        'deblur': '_deblurred',
        'upscale': '_upscaled',
        'clean': '_cleaned',
        'full': '_enhanced',
        'frame-gen': '_generated',
        'clean-frame-gen': '_clean_generated',
        'full-frame-gen': '_full_enhanced',
    }.get(mode, '_processed')
def ensure_ffmpeg():
    """Raise RuntimeError when the ffmpeg executable is not on PATH."""
    if not shutil.which('ffmpeg'):
        raise RuntimeError("ffmpeg not found. Please install ffmpeg to process videos.")
def extract_frames(video_path, output_dir, desc="Extracting frames"):
    """Dump every frame of *video_path* into *output_dir* as 8-digit PNGs.

    ffmpeg performs the extraction; progress is estimated by polling the
    number of PNGs written so far against OpenCV's reported frame count
    (which can be approximate for some containers).
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    cmd = [
        'ffmpeg', '-y', '-i', video_path,
        '-vsync', '0',  # passthrough: keep every decoded frame
        os.path.join(output_dir, '%08d.png')
    ]
    process = subprocess.Popen(cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
    with tqdm(total=total_frames, desc=desc, unit="frames",
              bar_format="{l_bar}{bar:30}{r_bar}") as pbar:
        # Poll the output directory while ffmpeg runs.
        while process.poll() is None:
            if os.path.exists(output_dir):
                current_frames = len([f for f in os.listdir(output_dir) if f.endswith('.png')])
                pbar.update(current_frames - pbar.n)
            time.sleep(0.1)
        # Final recount after ffmpeg exits so the bar reflects the true total.
        if os.path.exists(output_dir):
            current_frames = len([f for f in os.listdir(output_dir) if f.endswith('.png')])
            pbar.update(current_frames - pbar.n)
def extract_audio(video_path, audio_path):
    """Copy the audio stream of *video_path* to *audio_path* without
    re-encoding. Returns True on ffmpeg success, False otherwise (e.g.
    the video has no audio track)."""
    command = [
        'ffmpeg', '-y', '-i', video_path,
        '-vn', '-acodec', 'copy',
        audio_path
    ]
    completed = subprocess.run(command, capture_output=True)
    return completed.returncode == 0
def get_video_info(video_path):
    """Return (fps, frame_count, width, height) for *video_path*, read
    through OpenCV's VideoCapture properties."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    return fps, frame_count, width, height
def frames_to_video(frames_dir, output_path, fps, audio_path=None, desc="Compiling video"):
    """Encode the numbered PNG frames in *frames_dir* into an H.264 MP4.

    Encoding progress is parsed from ffmpeg's '-progress' key=value output
    (sent to stderr) in a background thread. When *audio_path* exists, the
    audio is muxed back in afterwards; if muxing fails the silent encode
    is kept as the final output.
    """
    frames = sorted([f for f in os.listdir(frames_dir) if f.endswith('.png')])
    total_frames = len(frames)
    temp_video = output_path + '_temp.mp4'
    cmd = [
        'ffmpeg', '-y',
        '-framerate', str(fps),
        '-i', os.path.join(frames_dir, '%08d.png'),
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',   # broad player compatibility
        '-crf', '18',            # visually near-lossless quality
        '-progress', 'pipe:2',   # machine-readable progress on stderr
        temp_video
    ]
    # Shared mutable cell: the reader thread writes, the main loop reads.
    compiled_frames = [0]
    process = subprocess.Popen(cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, bufsize=1)
    def _read_ffmpeg_progress(pipe):
        # Parse 'frame=N' lines emitted by ffmpeg's -progress output.
        for raw_line in iter(pipe.readline, b''):
            decoded = raw_line.decode('utf-8', errors='ignore').strip()
            if decoded.startswith('frame='):
                try:
                    compiled_frames[0] = int(decoded.split('=')[1].strip())
                except (ValueError, IndexError):
                    pass  # malformed line; keep the last good count
        pipe.close()
    reader = threading.Thread(target=_read_ffmpeg_progress, args=(process.stderr,), daemon=True)
    reader.start()
    last_reported = [0]
    with tqdm(total=total_frames, desc=desc, unit="frames",
              bar_format="{desc} | {n_fmt}/{total_fmt} frames {bar:25} {percentage:5.1f}% | {elapsed}<{remaining}, {rate_fmt}{postfix}") as pbar:
        while process.poll() is None:
            time.sleep(0.05)
            current = compiled_frames[0]
            if current > last_reported[0]:
                pbar.update(current - last_reported[0])
                last_reported[0] = current
                remaining = total_frames - current
                pbar.set_postfix_str(f"{remaining} remaining")
        reader.join(timeout=2)
        # Flush progress reported between the last poll and process exit.
        final = compiled_frames[0]
        if final > last_reported[0]:
            pbar.update(final - last_reported[0])
            last_reported[0] = final
        # Force the bar to 100% (the parsed counter can lag the true total).
        fill_needed = max(0, total_frames - pbar.n)
        if fill_needed > 0:
            pbar.update(fill_needed)
        pbar.set_postfix_str("")
    if audio_path and os.path.exists(audio_path):
        # Mux the original audio back in: video stream copied, audio -> AAC.
        cmd = [
            'ffmpeg', '-y',
            '-i', temp_video,
            '-i', audio_path,
            '-c:v', 'copy',
            '-c:a', 'aac',
            '-map', '0:v:0',
            '-map', '1:a:0?',  # '?' = tolerate a missing audio stream
            output_path
        ]
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode == 0 and os.path.exists(output_path):
            os.remove(temp_video)
            return
    # No audio, or muxing failed: promote the silent encode to the output.
    if os.path.exists(temp_video):
        if os.path.exists(output_path):
            os.remove(output_path)
        os.rename(temp_video, output_path)
def blend_frames_for_fps(frames_dir, target_fps, original_fps):
    """Raise the frame rate of the PNG sequence in *frames_dir* using
    ffmpeg's minterpolate blend filter.

    Returns the directory holding the blended frames. When no upsampling
    is needed (target <= original) — or *original_fps* is not a positive
    number — the input directory is returned unchanged.
    """
    # Guard against a zero/negative source rate (e.g. a container with
    # missing FPS metadata) instead of raising ZeroDivisionError on the
    # original's ratio computation.
    if original_fps <= 0 or target_fps <= original_fps:
        return frames_dir
    blended_dir = frames_dir + '_blended'
    os.makedirs(blended_dir, exist_ok=True)
    cmd = [
        'ffmpeg', '-y',
        '-framerate', str(original_fps),
        '-i', os.path.join(frames_dir, '%08d.png'),
        '-vf', f'minterpolate=fps={target_fps}:mi_mode=blend',
        '-vsync', '0',
        os.path.join(blended_dir, '%08d.png')
    ]
    subprocess.run(cmd, capture_output=True)
    return blended_dir
def process_image_denoise(img, step_bar=None):
    """Denoise a BGR image with the NAFNet denoise model.

    Progress goes through *step_bar* when provided, otherwise through the
    module-level tracker.
    """
    if step_bar:
        step_bar.update("Denoising...")
    else:
        progress.set_step("Denoising")
        progress.print_status()
    output = process_nafnet(load_denoise_model(), img2tensor(img))
    return tensor2img(output)
def process_image_deblur(img, step_bar=None):
    """Deblur a BGR image with the NAFNet deblur model.

    Progress goes through *step_bar* when provided, otherwise through the
    module-level tracker.
    """
    if step_bar:
        step_bar.update("Deblurring...")
    else:
        progress.set_step("Deblurring")
        progress.print_status()
    output = process_nafnet(load_deblur_model(), img2tensor(img))
    return tensor2img(output)
def process_image_upscale(img, step_bar=None, upscale_factor=4):
    """Upscale a BGR image with the 4x super-resolution model.

    upscale_factor == 2 is implemented as 4x inference followed by a
    Lanczos downscale to half the result.
    """
    if step_bar:
        step_bar.update(f"Upscaling x{upscale_factor}...")
    else:
        progress.set_step(f"Upscaling x{upscale_factor}")
        progress.print_status()
    result = tensor2img(process_upscale(load_upscale_model(), img2tensor(img)))
    if upscale_factor == 2:
        h, w = result.shape[:2]
        result = cv2.resize(result, (w // 2, h // 2), interpolation=cv2.INTER_LANCZOS4)
    return result
def process_image_clean(img, step_bar=None):
    """Run the 'clean' pipeline: denoise, then deblur."""
    return process_image_deblur(process_image_denoise(img, step_bar), step_bar)
def process_image_full(img, step_bar=None, upscale_factor=4):
    """Run the 'full' pipeline: denoise, deblur, then upscale."""
    cleaned = process_image_deblur(process_image_denoise(img, step_bar), step_bar)
    return process_image_upscale(cleaned, step_bar, upscale_factor)
def get_rife_scale():
    # Fixed RIFE inference scale; 1.0 = full-resolution flow estimation.
    return 1.0
def get_rife_padding_divisor(scale=1.0):
    """Divisor that RIFE frame dimensions must be padded to.

    Grows as the inference scale shrinks, with a floor of 64.
    """
    scaled = int(64 / scale)
    return scaled if scaled > 64 else 64
def pad_for_rife(img, scale=1.0):
    """Edge-pad an HxWxC image so both spatial dimensions are multiples of
    the RIFE padding divisor.

    Returns (padded_image, (orig_h, orig_w)) so callers can crop back.
    """
    divisor = get_rife_padding_divisor(scale)
    h, w = img.shape[:2]
    pad_h = (-h) % divisor
    pad_w = (-w) % divisor
    if pad_h or pad_w:
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')
    return img, (h, w)
def generate_frames(frames_dir, output_dir, multi=2):
    """Interpolate the PNG sequence in *frames_dir* by a factor of *multi*
    with RIFE, writing the expanded sequence into *output_dir*.

    For each consecutive pair the original first frame plus (multi - 1)
    interpolated frames are emitted; the final source frame is appended
    last, giving len(frames) * multi - (multi - 1) output frames.
    Raises ValueError with fewer than 2 input frames.
    """
    model = load_framegen_model()
    frames = sorted([f for f in os.listdir(frames_dir) if f.endswith('.png')])
    if len(frames) < 2:
        raise ValueError("Need at least 2 frames for generation")
    os.makedirs(output_dir, exist_ok=True)
    scale = get_rife_scale()
    output_idx = 0
    total_pairs = len(frames) - 1
    total_output_frames = len(frames) * multi - (multi - 1)
    pbar = tqdm(total=total_output_frames, desc=f"Generating (x{multi})", unit="frame",
                bar_format="{desc} | Frame {n_fmt}/{total_fmt} {bar:25} {percentage:5.1f}% | {elapsed}<{remaining} {postfix}")
    for i in range(total_pairs):
        pbar.set_postfix_str(f"| pair {i+1}/{total_pairs} [frame {i+1}\u2192{i+2}]")
        img0 = cv2.imread(os.path.join(frames_dir, frames[i]))
        img1 = cv2.imread(os.path.join(frames_dir, frames[i + 1]))
        # Both frames are padded so RIFE sees divisor-aligned dimensions;
        # outputs are cropped back to the original size on write.
        img0, (orig_h, orig_w) = pad_for_rife(img0, scale)
        img1, _ = pad_for_rife(img1, scale)
        img0_tensor = img2tensor(img0)
        img1_tensor = img2tensor(img1)
        # Emit the (cropped) source frame for this pair.
        cv2.imwrite(os.path.join(output_dir, f'{output_idx:08d}.png'), img0[:orig_h, :orig_w])
        output_idx += 1
        pbar.update(1)
        # Emit multi-1 intermediate frames at evenly spaced timesteps.
        for j in range(multi - 1):
            timestep = (j + 1) / multi
            with torch.no_grad():
                mid = model.inference(img0_tensor, img1_tensor, timestep, scale)
            mid_img = tensor2img(mid)
            cv2.imwrite(os.path.join(output_dir, f'{output_idx:08d}.png'), mid_img[:orig_h, :orig_w])
            output_idx += 1
            pbar.update(1)
    # The last source frame closes the sequence (read fresh, so unpadded).
    last_frame = cv2.imread(os.path.join(frames_dir, frames[-1]))
    cv2.imwrite(os.path.join(output_dir, f'{output_idx:08d}.png'), last_frame)
    pbar.update(1)
    pbar.set_postfix_str("")
    return total_output_frames
class _SilentStepBar:
def update(self, msg):
pass
def process_video_frames_step(frames_dir, output_dir, frames, step_name, process_func):
    """Apply *process_func* to each named frame, writing results under the
    same filenames into *output_dir*, with a tqdm bar labelled
    *step_name*. A silent step bar suppresses per-image status lines."""
    os.makedirs(output_dir, exist_ok=True)
    silent = _SilentStepBar()
    bar_fmt = "{desc} | Frame {n_fmt}/{total_fmt} {bar:25} {percentage:5.1f}% | {elapsed}<{remaining}, {rate_fmt}"
    with tqdm(total=len(frames), desc=step_name, unit="frame", bar_format=bar_fmt) as pbar:
        for name in frames:
            frame = cv2.imread(os.path.join(frames_dir, name))
            processed = process_func(frame, step_bar=silent)
            cv2.imwrite(os.path.join(output_dir, name), processed)
            pbar.update(1)
def process_video_multistep(video_path, output_path, steps, audio_path):
    """Extract frames, run each (name, func) step over them in sequence,
    then recompile the video at the original FPS.

    *steps* is a list of (step_name, process_func) pairs; each step reads
    the previous step's output directory. Returns the (width, height) of
    the processed frames. Temp directories under TEMP_DIR are managed by
    the caller.
    """
    original_fps, frame_count, width, height = get_video_info(video_path)
    frames_dir = os.path.join(TEMP_DIR, "frames")
    current_dir = frames_dir
    progress.set_step("Extracting frames")
    progress.print_status()
    extract_frames(video_path, frames_dir, desc="Extracting frames")
    extract_audio(video_path, audio_path)
    frames = sorted([f for f in os.listdir(frames_dir) if f.endswith('.png')])
    # Chain the steps: each writes to its own dir and feeds the next.
    for i, (step_name, process_func) in enumerate(steps):
        step_output_dir = os.path.join(TEMP_DIR, f"step_{i}")
        progress.set_step(step_name)
        progress.print_status()
        process_video_frames_step(current_dir, step_output_dir, frames, step_name, process_func)
        current_dir = step_output_dir
    # Output dimensions may differ from the input (e.g. after upscaling).
    sample_frame = cv2.imread(os.path.join(current_dir, frames[0]))
    new_height, new_width = sample_frame.shape[:2]
    progress.set_step("Compiling video")
    progress.print_status()
    frames_to_video(current_dir, output_path, original_fps,
                    audio_path if os.path.exists(audio_path) else None,
                    desc="Compiling video")
    return new_width, new_height
def process_video_frame_gen(video_path, output_path, multi=2, fps=None):
    """Frame-generation pipeline: interpolate *video_path* by *multi* and
    write the result to *output_path*.

    The target *fps* is clamped to [original_fps, original_fps * multi];
    None means "as high as possible". A target below the maximum is
    reached by generating at max rate and then blending down.
    """
    ensure_ffmpeg()
    original_fps, frame_count, width, height = get_video_info(video_path)
    min_fps = original_fps
    max_fps = original_fps * multi
    if fps is None:
        fps = max_fps
    elif fps < min_fps:
        print(f"Warning: Target FPS {fps:.2f} below minimum ({min_fps:.2f}). Using max: {max_fps:.2f}")
        fps = max_fps
    elif fps > max_fps:
        print(f"Warning: Target FPS {fps:.2f} exceeds maximum ({max_fps:.2f}). Using max: {max_fps:.2f}")
        fps = max_fps
    frames_dir = os.path.join(TEMP_DIR, "frames")
    gen_dir = os.path.join(TEMP_DIR, "generated")
    audio_path = os.path.join(TEMP_DIR, "audio.aac")
    # Start from a clean temp tree each run.
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR, exist_ok=True)
    progress.set_step("Extracting frames")
    progress.print_status()
    extract_frames(video_path, frames_dir, desc="Extracting frames")
    extract_audio(video_path, audio_path)
    progress.set_step(f"Generating x{multi}")
    progress.print_status()
    generate_frames(frames_dir, gen_dir, multi)
    final_frames_dir = gen_dir
    if fps < max_fps:
        # Intermediate target rate: blend down from the generated max rate.
        progress.set_step("Blending frames")
        progress.print_status()
        final_frames_dir = blend_frames_for_fps(gen_dir, fps, max_fps)
    progress.set_step("Compiling video")
    progress.print_status()
    frames_to_video(final_frames_dir, output_path, fps,
                    audio_path if os.path.exists(audio_path) else None,
                    desc="Compiling video")
    shutil.rmtree(TEMP_DIR)
def process_video_clean_frame_gen(video_path, output_path, multi=2, fps=None):
    """Clean + frame-generation pipeline: denoise and deblur every frame,
    then interpolate by *multi* and compile to *output_path*.

    FPS clamping behaves exactly as in process_video_frame_gen.
    """
    ensure_ffmpeg()
    original_fps, frame_count, width, height = get_video_info(video_path)
    min_fps = original_fps
    max_fps = original_fps * multi
    if fps is None:
        fps = max_fps
    elif fps < min_fps:
        print(f"Warning: Target FPS {fps:.2f} below minimum ({min_fps:.2f}). Using max: {max_fps:.2f}")
        fps = max_fps
    elif fps > max_fps:
        print(f"Warning: Target FPS {fps:.2f} exceeds maximum ({max_fps:.2f}). Using max: {max_fps:.2f}")
        fps = max_fps
    frames_dir = os.path.join(TEMP_DIR, "frames")
    denoised_dir = os.path.join(TEMP_DIR, "denoised")
    cleaned_dir = os.path.join(TEMP_DIR, "cleaned")
    gen_dir = os.path.join(TEMP_DIR, "generated")
    audio_path = os.path.join(TEMP_DIR, "audio.aac")
    # Start from a clean temp tree each run.
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR, exist_ok=True)
    progress.set_step("Extracting frames")
    progress.print_status()
    extract_frames(video_path, frames_dir, desc="Extracting frames")
    extract_audio(video_path, audio_path)
    frames = sorted([f for f in os.listdir(frames_dir) if f.endswith('.png')])
    # Clean first so RIFE interpolates artifact-free frames.
    process_video_frames_step(frames_dir, denoised_dir, frames, "Denoising frames", process_image_denoise)
    process_video_frames_step(denoised_dir, cleaned_dir, frames, "Deblurring frames", process_image_deblur)
    progress.set_step(f"Generating x{multi}")
    progress.print_status()
    generate_frames(cleaned_dir, gen_dir, multi)
    final_frames_dir = gen_dir
    if fps < max_fps:
        # Intermediate target rate: blend down from the generated max rate.
        progress.set_step("Blending frames")
        progress.print_status()
        final_frames_dir = blend_frames_for_fps(gen_dir, fps, max_fps)
    progress.set_step("Compiling video")
    progress.print_status()
    frames_to_video(final_frames_dir, output_path, fps,
                    audio_path if os.path.exists(audio_path) else None,
                    desc="Compiling video")
    shutil.rmtree(TEMP_DIR)
def process_video_full_frame_gen(video_path, output_path, multi=2, fps=None, upscale_factor=4):
    """Full pipeline with interpolation: denoise, deblur, upscale, then
    interpolate by *multi* and compile to *output_path*.

    FPS clamping behaves exactly as in process_video_frame_gen.
    """
    ensure_ffmpeg()
    original_fps, frame_count, width, height = get_video_info(video_path)
    min_fps = original_fps
    max_fps = original_fps * multi
    if fps is None:
        fps = max_fps
    elif fps < min_fps:
        print(f"Warning: Target FPS {fps:.2f} below minimum ({min_fps:.2f}). Using max: {max_fps:.2f}")
        fps = max_fps
    elif fps > max_fps:
        print(f"Warning: Target FPS {fps:.2f} exceeds maximum ({max_fps:.2f}). Using max: {max_fps:.2f}")
        fps = max_fps
    frames_dir = os.path.join(TEMP_DIR, "frames")
    denoised_dir = os.path.join(TEMP_DIR, "denoised")
    cleaned_dir = os.path.join(TEMP_DIR, "cleaned")
    upscaled_dir = os.path.join(TEMP_DIR, "upscaled")
    gen_dir = os.path.join(TEMP_DIR, "generated")
    audio_path = os.path.join(TEMP_DIR, "audio.aac")
    # Start from a clean temp tree each run.
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR, exist_ok=True)
    progress.set_step("Extracting frames")
    progress.print_status()
    extract_frames(video_path, frames_dir, desc="Extracting frames")
    extract_audio(video_path, audio_path)
    frames = sorted([f for f in os.listdir(frames_dir) if f.endswith('.png')])
    # Enhance before interpolation so RIFE works on the final-quality frames.
    process_video_frames_step(frames_dir, denoised_dir, frames, "Denoising frames", process_image_denoise)
    process_video_frames_step(denoised_dir, cleaned_dir, frames, "Deblurring frames", process_image_deblur)
    process_video_frames_step(cleaned_dir, upscaled_dir, frames, f"Upscaling frames (x{upscale_factor})", lambda img, step_bar=None: process_image_upscale(img, step_bar, upscale_factor))
    # Re-list: subsequent steps index frames from the upscaled directory.
    frames = sorted([f for f in os.listdir(upscaled_dir) if f.endswith('.png')])
    progress.set_step(f"Generating x{multi}")
    progress.print_status()
    generate_frames(upscaled_dir, gen_dir, multi)
    final_frames_dir = gen_dir
    if fps < max_fps:
        # Intermediate target rate: blend down from the generated max rate.
        progress.set_step("Blending frames")
        progress.print_status()
        final_frames_dir = blend_frames_for_fps(gen_dir, fps, max_fps)
    # Read one frame to learn the upscaled output dimensions.
    sample_frame = cv2.imread(os.path.join(final_frames_dir, frames[0]))
    new_height, new_width = sample_frame.shape[:2]
    progress.set_step("Compiling video")
    progress.print_status()
    frames_to_video(final_frames_dir, output_path, fps,
                    audio_path if os.path.exists(audio_path) else None,
                    desc="Compiling video")
    shutil.rmtree(TEMP_DIR)
def process_video(video_path, output_path, mode, upscale_factor=4):
    """Run a non-frame-gen restoration pipeline on a video, frame by frame.

    Args:
        video_path: Path to the source video file.
        output_path: Destination path for the processed video.
        mode: One of 'denoise', 'deblur', 'upscale', 'clean', 'full'.
        upscale_factor: Multiplier used by upscaling steps (2 or 4).

    Returns:
        (new_width, new_height) of the processed frames, as reported by
        process_video_multistep.

    Raises:
        ValueError: If `mode` is not one of the supported modes.
    """
    ensure_ffmpeg()
    # Each mode maps to an ordered list of (step label, per-frame function).
    mode_steps = {
        'denoise': [("Denoising frames", process_image_denoise)],
        'deblur': [("Deblurring frames", process_image_deblur)],
        'upscale': [(f"Upscaling frames (x{upscale_factor})", lambda img, step_bar=None: process_image_upscale(img, step_bar, upscale_factor))],
        'clean': [
            ("Denoising frames", process_image_denoise),
            ("Deblurring frames", process_image_deblur),
        ],
        'full': [
            ("Denoising frames", process_image_denoise),
            ("Deblurring frames", process_image_deblur),
            (f"Upscaling frames (x{upscale_factor})", lambda img, step_bar=None: process_image_upscale(img, step_bar, upscale_factor)),
        ],
    }
    if mode not in mode_steps:
        raise ValueError(f"Unknown mode: {mode}")
    # Probe the video up front so an unreadable input fails before we touch
    # the temp workspace; the returned metadata is not needed here.
    get_video_info(video_path)
    audio_path = os.path.join(TEMP_DIR, "audio.aac")
    # Always start from a clean temp workspace.
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR, exist_ok=True)
    new_width, new_height = process_video_multistep(
        video_path, output_path, mode_steps[mode], audio_path
    )
    shutil.rmtree(TEMP_DIR)
    return new_width, new_height
def process_single_file(input_path, output_path, mode, multi=2, fps=None, upscale_factor=4, show_progress=True):
    """Process one image or video file with the requested mode.

    Args:
        input_path: Source file (image or video).
        output_path: Destination path, or None to auto-generate one next to
            the input. A mismatched extension is replaced with the input's.
        mode: Processing mode; frame-gen modes are valid for videos only.
        multi: Frame multiplier for frame-gen modes.
        fps: Optional target FPS for frame-gen modes.
        upscale_factor: Multiplier for upscaling steps (2 or 4).
        show_progress: Currently unused; kept for interface compatibility.

    Returns:
        The actual output path written.

    Raises:
        FileNotFoundError: If the input does not exist.
        ValueError: On unsupported format, invalid mode, unreadable image,
            or a frame-gen mode applied to an image.
        IOError: If the processed image cannot be written.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        raise FileNotFoundError(f"Input not found: {input_path}")
    is_img = is_image(input_path)
    is_vid = is_video(input_path)
    if not is_img and not is_vid:
        raise ValueError(f"Unsupported format: {input_path.suffix}")
    frame_gen_modes = ['frame-gen', 'clean-frame-gen', 'full-frame-gen']
    if mode in frame_gen_modes and is_img:
        raise ValueError(f"Frame generation mode '{mode}' only works with videos, not images.")
    if output_path is None:
        output_path = generate_output_path(input_path, mode)
    output_p = Path(output_path)
    # Force an output extension appropriate for the media type.
    valid_extensions = IMAGE_EXTENSIONS if is_img else VIDEO_EXTENSIONS
    if output_p.suffix.lower() not in valid_extensions:
        output_path = str(output_p.with_suffix(input_path.suffix))
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    if is_vid:
        if mode == 'frame-gen':
            process_video_frame_gen(str(input_path), output_path, multi, fps)
        elif mode == 'clean-frame-gen':
            process_video_clean_frame_gen(str(input_path), output_path, multi, fps)
        elif mode == 'full-frame-gen':
            process_video_full_frame_gen(str(input_path), output_path, multi, fps, upscale_factor)
        else:
            process_video(str(input_path), output_path, mode, upscale_factor)
    else:
        img = cv2.imread(str(input_path))
        if img is None:
            raise ValueError(f"Failed to read image: {input_path}")
        # Single table: mode -> (progress step count, processing callable).
        # Keeping these together avoids the two dicts drifting out of sync.
        image_pipelines = {
            'denoise': (1, lambda i, bar: process_image_denoise(i, bar)),
            'deblur': (1, lambda i, bar: process_image_deblur(i, bar)),
            'upscale': (1, lambda i, bar: process_image_upscale(i, bar, upscale_factor)),
            'clean': (2, lambda i, bar: process_image_clean(i, bar)),
            'full': (3, lambda i, bar: process_image_full(i, bar, upscale_factor)),
        }
        # Validate before constructing the progress bar so an invalid mode
        # doesn't leave a half-initialized bar on screen.
        if mode not in image_pipelines:
            raise ValueError(f"Unknown mode: {mode}")
        num_steps, run_pipeline = image_pipelines[mode]
        step_bar = StepProgressBar(num_steps, input_path.name)
        result = run_pipeline(img, step_bar)
        step_bar.finish()
        success = cv2.imwrite(output_path, result)
        if not success:
            raise IOError(f"Failed to write image: {output_path}")
    return output_path
def process_multiple_files(input_paths, output_arg, mode, multi=2, fps=None, upscale_factor=4, overwrite=True):
    """Process a batch of files, resolving output locations from `output_arg`.

    Output resolution rules:
      - `output_arg` with a file suffix: used verbatim for a single input;
        with multiple inputs its stem becomes an output folder.
      - Trailing '/' or '\\': always treated as a folder.
      - No `output_arg`: a suffixed sibling file (single input) or a
        "<stem>_<mode>" sibling folder (multiple inputs) is used.

    Args:
        input_paths: List of input file paths.
        output_arg: User-supplied output file/folder hint, or None.
        mode: Processing mode passed through to process_single_file.
        multi, fps, upscale_factor: Forwarded to process_single_file.
        overwrite: Currently unused; kept for interface compatibility.

    Returns:
        List of output paths actually written for successful files.
    """
    total_files = len(input_paths)
    processed = 0
    errors = 0
    output_folder = None
    single_file_output = None
    output_files = []
    if output_arg:
        output_arg = output_arg.strip()
        output_path = Path(output_arg)
        if output_path.suffix:
            if total_files == 1:
                single_file_output = output_arg
            else:
                # Multiple inputs can't share one file: reuse the stem as a folder.
                parent = output_path.parent
                if parent and str(parent) != '.':
                    output_folder = str(parent / output_path.stem)
                else:
                    output_folder = output_path.stem
        elif output_arg.endswith('/') or output_arg.endswith('\\'):
            output_folder = output_arg
        else:
            if total_files == 1:
                single_file_output = output_arg
            else:
                output_folder = output_arg
    else:
        if total_files == 1:
            first_input = Path(input_paths[0])
            mode_suffix = get_mode_suffix(mode)
            single_file_output = str(first_input.parent / f"{first_input.stem}{mode_suffix}{first_input.suffix}")
        else:
            first_input = Path(input_paths[0])
            output_folder = str(first_input.parent / f"{first_input.stem}_{mode}")
    progress.start_batch(total_files)
    print(f"\n{'='*60}")
    print(f"Processing {total_files} file(s) - Mode: {mode}")
    print(f"Model mode: {get_model_mode().upper()}")
    print(f"{'='*60}\n")
    for input_path in input_paths:
        try:
            if single_file_output:
                output_path = single_file_output
            elif output_folder:
                input_p = Path(input_path)
                mode_suffix = get_mode_suffix(mode)
                os.makedirs(output_folder, exist_ok=True)
                output_path = os.path.join(output_folder, f"{input_p.stem}{mode_suffix}{input_p.suffix}")
            else:
                output_path = None
            progress.start_file(input_path)
            progress.set_step("Loading")
            progress.print_status(force=True)
            # Use the path returned by process_single_file: it may differ from
            # the requested one (auto-generated name, corrected extension).
            actual_output = process_single_file(input_path, output_path, mode, multi, fps, upscale_factor)
            elapsed = progress.finish_file()
            progress.set_step(f"Done ({elapsed:.1f}s)")
            progress.print_status(force=True)
            progress.print_newline()
            processed += 1
            if actual_output:
                output_files.append(actual_output)
        except Exception as e:
            progress.set_step(f"Error: {str(e)[:30]}")
            progress.print_status(force=True)
            progress.print_newline()
            errors += 1
    print(f"\n{'='*60}")
    print(f"Completed: {processed}/{total_files} files")
    print(f"Total time: {progress.get_elapsed_str()}")
    if errors > 0:
        print(f"Errors: {errors}")
    print(f"{'='*60}")
    return output_files
def process_file_pairs(file_pairs, mode, multi=2, fps=None, upscale_factor=4, file_type="file"):
    """Process explicit (input_path, output_path) pairs sequentially.

    Unlike process_multiple_files, output paths are already resolved by the
    caller. Errors on individual files are reported and counted but do not
    abort the batch.

    Args:
        file_pairs: List of (input_path, output_path) tuples.
        mode: Processing mode forwarded to process_single_file.
        multi, fps, upscale_factor: Forwarded to process_single_file.
        file_type: Label ("file", "image", "video") used in status output.
    """
    total_files = len(file_pairs)
    if total_files == 0:
        return
    processed = 0
    errors = 0
    progress.start_batch(total_files)
    print(f"\n{'='*60}")
    print(f"Processing {total_files} {file_type}(s) - Mode: {mode}")
    print(f"Model mode: {get_model_mode().upper()}")
    print(f"{'='*60}\n")
    for input_path, output_path in file_pairs:
        try:
            progress.start_file(input_path)
            progress.set_step("Loading")
            progress.print_status(force=True)
            process_single_file(input_path, output_path, mode, multi, fps, upscale_factor)
            elapsed = progress.finish_file()
            progress.set_step(f"Done ({elapsed:.1f}s)")
            progress.print_status(force=True)
            progress.print_newline()
            processed += 1
        except Exception as e:
            progress.set_step(f"Error: {str(e)[:30]}")
            progress.print_status(force=True)
            progress.print_newline()
            errors += 1
    print(f"\n{'='*60}")
    print(f"Completed: {processed}/{total_files} {file_type}(s)")
    print(f"Total time: {progress.get_elapsed_str()}")
    if errors > 0:
        print(f"Errors: {errors}")
    print(f"{'='*60}")
def show_info():
    """Print environment diagnostics: Python/PyTorch/CUDA, ffmpeg, and the
    on-disk status of both heavy and lite model sets."""
    def _print_model_table(model_paths):
        # One status line per model: found/missing plus on-disk size in MB.
        for name, path in model_paths.items():
            status = "✓ Found" if os.path.exists(path) else "✗ Missing"
            size = ""
            if os.path.exists(path):
                size = f" ({os.path.getsize(path) / (1024*1024):.1f} MB)"
            print(f" {name}: {status}{size}")
    print("\n" + "="*60)
    print("KLARITY - Image/Video Restoration Tool")
    print("="*60)
    print(f"\nPython: {sys.version.split()[0]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
    print(f"\nffmpeg: {'Available' if shutil.which('ffmpeg') else 'NOT FOUND'}")
    print(f"\nCurrent model mode: {get_model_mode().upper()}")
    print("\nModels (Heavy):")
    _print_model_table(get_model_paths_for_mode(SCRIPT_DIR, 'heavy'))
    print("\nModels (Lite):")
    _print_model_table(get_model_paths_for_mode(SCRIPT_DIR, 'lite'))
    print("\n" + "="*60)
def auto_download_models_for_mode():
    """Ensure all models for the current mode exist on disk, downloading any
    that are missing or look like truncated downloads.

    Returns:
        True if every model is present (or was downloaded successfully),
        False if any download failed.
    """
    # Files smaller than this are assumed to be failed/partial downloads.
    min_valid_bytes = 1000
    current_mode = get_model_mode()
    model_paths = get_model_paths()
    missing = [
        key for key, path in model_paths.items()
        if not os.path.exists(path) or os.path.getsize(path) < min_valid_bytes
    ]
    if not missing:
        return True
    print(f"\nMissing {current_mode} models detected: {', '.join(missing)}")
    print(f"Auto-downloading {current_mode} models...\n")
    result = ensure_models(SCRIPT_DIR, model_paths, auto_download=True, prompt=False, mode=current_mode)
    if result:
        print(f"\nAll {current_mode} models ready!\n")
    else:
        print(f"\nFailed to download some models. Please check your internet connection.\n")
    return result
def download_models_command(mode=None):
    """Download every model for the given mode (or the current one).

    Args:
        mode: Optional model mode ('heavy'/'lite'); when given, it becomes
            the active mode before downloading.
    """
    banner = "=" * 60
    print("\n" + banner)
    if mode:
        print(f"Downloading all {mode} models...")
        set_model_mode(mode)
    else:
        print("Downloading all required models...")
    print(banner + "\n")
    active_mode = get_model_mode()
    paths = get_model_paths()
    ok = ensure_models(SCRIPT_DIR, paths, auto_download=True, prompt=False, mode=active_mode)
    # Closing banner wraps either the success or the failure message.
    print("\n" + banner)
    if ok:
        print(f"All {active_mode} models downloaded successfully!")
    else:
        print("Some models failed to download. Please check your connection.")
    print(banner)
def interactive_mode():
    """Run the fully interactive CLI wizard.

    Walks the user through, in order: model-mode selection and model
    download check, input path collection, per-type (image/video) mode and
    parameter prompts, output path prompts, device selection, and batch
    processing. Loops for another session until the user chooses to exit.
    """
    global device
    print("\n" + "="*60)
    print("KLARITY - Image/Video Restoration Tool")
    print("="*60)
    # Choose heavy/lite models up front and make sure they exist on disk.
    mode = select_model_mode()
    set_model_mode(mode)
    print(f"\nModel mode set to: {mode.upper()}")
    check_and_download_models()
    # Remove any temp folder left behind by a crashed previous run.
    if os.path.exists(TEMP_DIR):
        print(f"\nCleaning up leftover temp folder from previous session...")
        try:
            shutil.rmtree(TEMP_DIR)
            print("Temp folder cleaned successfully.")
        except Exception as e:
            print(f"Warning: Could not clean temp folder: {e}")
    # Menu-number -> mode-name lookup tables for the prompts below.
    image_modes = {
        '1': 'denoise',
        '2': 'deblur',
        '3': 'upscale',
        '4': 'clean',
        '5': 'full',
    }
    video_modes = {
        '1': 'denoise',
        '2': 'deblur',
        '3': 'upscale',
        '4': 'clean',
        '5': 'full',
        '6': 'frame-gen',
        '7': 'clean-frame-gen',
        '8': 'full-frame-gen',
    }
    # Only prompt for a device once per process (a caller may have set it).
    device_selected = device is not None
    # Outer loop: one full processing session per iteration.
    while True:
        # Input-selection loop: repeats until we have at least one valid file.
        while True:
            print("\n" + "-"*60)
            print("INPUT SELECTION")
            print("-"*60)
            print("\nEnter input path(s) - separate multiple paths with spaces")
            print("Use quotes for paths with spaces, or 'q' to quit:")
            input_arg = input("> ").strip()
            if input_arg.lower() in ['q', 'quit', 'exit']:
                print("\nExiting Klarity. Goodbye!")
                return
            if not input_arg:
                print("Error: No input provided. Please enter at least one path or 'q' to quit.")
                continue
            parsed_paths = parse_multiple_paths(input_arg)
            if not parsed_paths:
                print("Error: Could not parse any paths from input. Please try again.")
                continue
            categorized = categorize_multiple_paths(parsed_paths)
            display_path_summary(categorized)
            input_paths = categorized['all_valid_files']
            if not input_paths:
                print("\n" + "-"*40)
                has_errors = categorized['not_exist'] or categorized['not_supported'] or categorized['invalid']
                if has_errors:
                    print("Options:")
                    print(" 1. Enter different paths")
                    print(" 2. Exit")
                    retry_choice = input("\nSelect option (1 or 2): ").strip()
                    if retry_choice == '1':
                        continue
                    else:
                        print("\nExiting Klarity. Goodbye!")
                        return
                else:
                    print("Please enter valid paths.")
                    continue
            else:
                break
        # Group each original input path with its validated image/video files.
        input_entries = []
        for orig_path, valid_files in categorized['valid']:
            entry = {
                'original_path': orig_path,
                'is_folder': Path(orig_path).is_dir(),
                'images': [f for f in valid_files if is_image(f)],
                'videos': [f for f in valid_files if is_video(f)],
                'name': Path(orig_path).name if Path(orig_path).is_dir() else Path(orig_path).stem
            }
            input_entries.append(entry)
        input_entries.sort(key=lambda x: x['name'].lower())
        total_image_inputs = sum(1 for e in input_entries if e['images'])
        total_video_inputs = sum(1 for e in input_entries if e['videos'])
        total_images = sum(len(e['images']) for e in input_entries)
        total_videos = sum(len(e['videos']) for e in input_entries)
        print(f"\n" + "-"*60)
        print("FILES TO PROCESS")
        print("-"*60)
        print(f" Total inputs: {len(input_entries)}")
        print(f" Images: {total_images} (from {total_image_inputs} input{'s' if total_image_inputs != 1 else ''})")
        print(f" Videos: {total_videos} (from {total_video_inputs} input{'s' if total_video_inputs != 1 else ''})")
        # Ask which media types to process; require at least one of the two.
        process_images = False
        process_videos = False
        while True:
            if total_images > 0:
                choice = input("\nProcess images? (y/n): ").strip().lower()
                if choice in ['y', 'yes']:
                    process_images = True
                else:
                    print(" → Images ignored")
            if total_videos > 0:
                choice = input("\nProcess videos? (y/n): ").strip().lower()
                if choice in ['y', 'yes']:
                    process_videos = True
                else:
                    print(" → Videos ignored")
            if not process_images and not process_videos:
                print("\nError: Nothing selected to process.")
                retry = input("Would you like to try again? (y/n): ").strip().lower()
                if retry in ['y', 'yes']:
                    continue
                else:
                    print("\nExiting Klarity. Goodbye!")
                    return
            else:
                break
        # Image settings: pick a mode, then an upscale factor if relevant.
        image_mode = None
        image_upscale_factor = 4
        if process_images:
            print("\n" + "-"*60)
            print("IMAGE SETTINGS")
            print("-"*60)
            print("\nSelect mode for images:")
            print(" 1. Denoise (remove noise)")
            print(" 2. Deblur (remove blur)")
            print(" 3. Upscale (4x upscaling)")
            print(" 4. Clean (denoise + deblur)")
            print(" 5. Full (denoise + deblur + upscale)")
            while True:
                choice = input("\nSelect mode (1-5): ").strip()
                if choice in image_modes:
                    image_mode = image_modes[choice]
                    break
                else:
                    print(f"Error: Invalid choice '{choice}'. Please enter a number from 1 to 5.")
            if image_mode in ['upscale', 'full']:
                print("\nUpscale factor options:")
                print(" 4 = 4x upscale (default)")
                print(" 2 = 2x upscale")
                while True:
                    upscale_choice = input("Select upscale factor (2 or 4, default 4): ").strip()
                    if upscale_choice == '' or upscale_choice == '4':
                        image_upscale_factor = 4
                        break
                    elif upscale_choice == '2':
                        image_upscale_factor = 2
                        break
                    else:
                        print(f"Error: Invalid choice '{upscale_choice}'. Please enter 2 or 4.")
        # Video settings: mode, optional upscale factor, frame-gen parameters.
        video_mode = None
        multi = 2
        fps = None
        video_upscale_factor = 4
        if process_videos:
            print("\n" + "-"*60)
            print("VIDEO SETTINGS")
            print("-"*60)
            print("\nSelect mode for videos:")
            print(" 1. Denoise (remove noise)")
            print(" 2. Deblur (remove blur)")
            print(" 3. Upscale (4x upscaling)")
            print(" 4. Clean (denoise + deblur)")
            print(" 5. Full (denoise + deblur + upscale)")
            print(" 6. Frame Generation (interpolate video frames)")
            print(" 7. Clean + Frame Generation")
            print(" 8. Full + Frame Generation")
            while True:
                choice = input("\nSelect mode (1-8): ").strip()
                if choice in video_modes:
                    video_mode = video_modes[choice]
                    break
                else:
                    print(f"Error: Invalid choice '{choice}'. Please enter a number from 1 to 8.")
            if video_mode in ['upscale', 'full']:
                print("\nUpscale factor options:")
                print(" 4 = 4x upscale (default)")
                print(" 2 = 2x upscale")
                while True:
                    upscale_choice = input("Select upscale factor (2 or 4, default 4): ").strip()
                    if upscale_choice == '' or upscale_choice == '4':
                        video_upscale_factor = 4
                        break
                    elif upscale_choice == '2':
                        video_upscale_factor = 2
                        break
                    else:
                        print(f"Error: Invalid choice '{upscale_choice}'. Please enter 2 or 4.")
            if video_mode == 'full-frame-gen':
                print("\nUpscale factor options:")
                print(" 4 = 4x upscale (default)")
                print(" 2 = 2x upscale")
                while True:
                    upscale_choice = input("Select upscale factor (2 or 4, default 4): ").strip()
                    if upscale_choice == '' or upscale_choice == '4':
                        video_upscale_factor = 4
                        break
                    elif upscale_choice == '2':
                        video_upscale_factor = 2
                        break
                    else:
                        print(f"Error: Invalid choice '{upscale_choice}'. Please enter 2 or 4.")
            if video_mode in ['frame-gen', 'clean-frame-gen', 'full-frame-gen']:
                print("\nFrame multiplier options:")
                print(" 2 = double the frame rate")
                print(" 4 = quadruple the frame rate")
                while True:
                    multi_choice = input("Select multiplier (2 or 4, default 2): ").strip()
                    if multi_choice == '' or multi_choice == '2':
                        multi = 2
                        break
                    elif multi_choice == '4':
                        multi = 4
                        break
                    else:
                        print(f"Error: Invalid choice '{multi_choice}'. Please enter 2 or 4.")
                # With exactly one video we can show a concrete FPS range;
                # otherwise the target FPS is taken on faith and clamped later.
                single_video_inputs = [e for e in input_entries if e['videos']]
                video_count = sum(len(e['videos']) for e in single_video_inputs)
                if video_count == 1:
                    video_path = single_video_inputs[0]['videos'][0]
                    original_fps, _, _, _ = get_video_info(video_path)
                    min_fps = original_fps
                    max_fps = original_fps * multi
                    fps_input = input(f"Target FPS [{min_fps:.2f}/{max_fps:.2f}] (press Enter for max): ").strip()
                    if fps_input:
                        try:
                            fps = float(fps_input)
                        except ValueError:
                            print("Invalid FPS, using auto max")
                else:
                    fps_input = input("Target FPS (press Enter for auto max): ").strip()
                    if fps_input:
                        try:
                            fps = float(fps_input)
                        except ValueError:
                            print("Invalid FPS, using auto max")
        # Output settings: one prompt per input entry and media type,
        # offering an auto-generated default path.
        print("\n" + "-"*60)
        print("OUTPUT SETTINGS")
        print("-"*60)
        print("\nPress Enter for auto-default, or enter custom path.")
        print("For folders: outputs to a new folder with processed files.")
        print("For files: outputs next to original with suffix.\n")
        image_input_idx = 0
        video_input_idx = 0
        input_outputs = {}
        for entry in input_entries:
            input_path = entry['original_path']
            is_folder = entry['is_folder']
            input_name = entry['name']
            images = entry['images'] if process_images else []
            videos = entry['videos'] if process_videos else []
            input_outputs[input_path] = {'images': None, 'videos': None}
            if images:
                image_input_idx += 1
                image_count = len(images)
                tracker = f"images [{image_input_idx}/{total_image_inputs}]"
                display_name = input_name if is_folder else Path(input_path).name
                if is_folder:
                    default_output = str(Path(input_path).parent / f"{input_name}_{image_mode}")
                    print(f"\n{tracker} \"{display_name}/\" ({image_count} images)")
                    print(f" Auto: {default_output}/")
                else:
                    if image_count == 1:
                        default_output = str(Path(input_path).parent / f"{input_name}_{image_mode}{Path(input_path).suffix}")
                        print(f"\n{tracker} \"{display_name}\"")
                        print(f" Auto: {default_output}")
                    else:
                        default_output = str(Path(input_path).parent / f"{input_name}_{image_mode}")
                        print(f"\n{tracker} \"{display_name}\" ({image_count} images)")
                        print(f" Auto: {default_output}/")
                user_output = input("> ").strip()
                if user_output:
                    # Strip surrounding quotes from drag-and-dropped paths.
                    if user_output.startswith('"') and user_output.endswith('"'):
                        user_output = user_output[1:-1]
                    input_outputs[input_path]['images'] = user_output
                else:
                    input_outputs[input_path]['images'] = default_output
            if videos:
                video_input_idx += 1
                video_count = len(videos)
                tracker = f"videos [{video_input_idx}/{total_video_inputs}]"
                display_name = input_name if is_folder else Path(input_path).name
                if is_folder:
                    default_output = str(Path(input_path).parent / f"{input_name}_{video_mode}")
                    print(f"\n{tracker} \"{display_name}/\" ({video_count} videos)")
                    print(f" Auto: {default_output}/")
                else:
                    if video_count == 1:
                        default_output = str(Path(input_path).parent / f"{input_name}_{video_mode}{Path(input_path).suffix}")
                        print(f"\n{tracker} \"{display_name}\"")
                        print(f" Auto: {default_output}")
                    else:
                        default_output = str(Path(input_path).parent / f"{input_name}_{video_mode}")
                        print(f"\n{tracker} \"{display_name}\" ({video_count} videos)")
                        print(f" Auto: {default_output}/")
                user_output = input("> ").strip()
                if user_output:
                    if user_output.startswith('"') and user_output.endswith('"'):
                        user_output = user_output[1:-1]
                    input_outputs[input_path]['videos'] = user_output
                else:
                    input_outputs[input_path]['videos'] = default_output
        # Device selection happens once; later sessions reuse the choice.
        if not device_selected:
            print("\n" + "-"*60)
            print("DEVICE SELECTION")
            print("-"*60)
            select_device()
            device_selected = True
        else:
            print(f"\nUsing previously selected device: {device}")
        # Expand each entry's files into concrete (input, output) pairs.
        image_pairs = []
        video_pairs = []
        for entry in input_entries:
            input_path = entry['original_path']
            images = entry['images'] if process_images else []
            videos = entry['videos'] if process_videos else []
            if images:
                user_output = input_outputs[input_path]['images']
                for img_file in images:
                    output = generate_output_path(img_file, image_mode, output_arg=user_output)
                    image_pairs.append((img_file, output))
            if videos:
                user_output = input_outputs[input_path]['videos']
                for vid_file in videos:
                    output = generate_output_path(vid_file, video_mode, output_arg=user_output)
                    video_pairs.append((vid_file, output))
        if image_pairs:
            process_file_pairs(image_pairs, image_mode, multi=2, fps=None, upscale_factor=image_upscale_factor, file_type="image")
        if video_pairs:
            process_file_pairs(video_pairs, video_mode, multi, fps, upscale_factor=video_upscale_factor, file_type="video")
        print("\n" + "="*60)
        print("ALL PROCESSING COMPLETE!")
        print("="*60)
        # Post-run menu: start a fresh session or exit.
        print("\nWhat would you like to do next?")
        print(" 1. Process again (start fresh)")
        print(" 2. Exit")
        post_choice = input("\nSelect option (1 or 2): ").strip()
        if post_choice == '1':
            print("\nStarting fresh session...")
            if os.path.exists(TEMP_DIR):
                try:
                    shutil.rmtree(TEMP_DIR)
                except Exception:
                    pass
            continue
        else:
            print("\nExiting Klarity. Goodbye!")
            return
def main():
    """CLI entry point: parse arguments and dispatch to GUI, interactive CLI,
    an informational command, or batch processing.

    Side effects: sets the global JSON_PROGRESS flag, selects the model mode,
    cleans the leftover temp directory, and may trigger model downloads.
    """
    global JSON_PROGRESS
    parser = argparse.ArgumentParser(
        description="KLARITY - Image/Video Restoration Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python klarity.py                              # Launch GUI (default)
  python klarity.py cli                          # Interactive CLI mode
  python klarity.py denoise image.jpg            # Denoise image
  python klarity.py -lite denoise image.jpg      # Denoise with lite model
  python klarity.py -heavy upscale video.mp4     # Upscale with heavy model
  python klarity.py frame-gen video.mp4 --multi 2  # Frame interpolation
"""
    )
    parser.add_argument('command', nargs='?', default='gui',
                        help="Command: gui (default), cli, denoise, deblur, upscale, clean, full, frame-gen, clean-frame-gen, full-frame-gen, info, download-models")
    parser.add_argument('input', nargs='?', help="Input file or folder")
    parser.add_argument('-o', '--output', help="Output file or folder")
    parser.add_argument('--multi', type=int, choices=[2, 4], default=2,
                        help="Frame multiplier for frame-gen (2 or 4)")
    parser.add_argument('--upscale', type=int, choices=[2, 4], default=4,
                        help="Upscale factor (2 or 4, default 4)")
    parser.add_argument('--fps', type=float, help="Target FPS for frame generation")
    parser.add_argument('--scale', type=float, choices=[0.5, 1.0, 2.0], default=1.0,
                        help="RIFE scale factor")
    parser.add_argument('--device', choices=['cpu', 'gpu', 'auto'], default='auto',
                        help="Device to use (cpu, gpu, or auto)")
    parser.add_argument('--cpu', action='store_true', help="Force CPU (legacy, same as --device cpu)")
    parser.add_argument('-heavy', action='store_true', help="Use heavy models (default, better quality)")
    parser.add_argument('-lite', action='store_true', help="Use lite models (faster, smaller)")
    parser.add_argument('--json-progress', action='store_true', help="Output progress as JSON (for GUI)")
    args = parser.parse_args()
    JSON_PROGRESS = args.json_progress
    if args.lite and args.heavy:
        print("Error: Cannot specify both -lite and -heavy flags")
        return
    # Heavy is the default model mode unless -lite was given.
    set_model_mode('lite' if args.lite else 'heavy')
    # Clean leftovers from an interrupted previous run; stay quiet in JSON
    # mode so stdout remains machine-parseable.
    if os.path.exists(TEMP_DIR):
        try:
            shutil.rmtree(TEMP_DIR)
            if not JSON_PROGRESS:
                print("Cleaned up leftover temporary folder.")
        except Exception as e:
            if not JSON_PROGRESS:
                print(f"Warning: Could not clean temp folder: {e}")
    # nargs='?' with default='gui' means a missing command also lands here.
    if args.command == 'gui':
        try:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            gui_path = os.path.join(script_dir, "gui.py")
            if os.path.exists(gui_path):
                from gui import main as gui_main
                gui_main()
            else:
                print("Error: gui.py not found. Please ensure gui.py is in the same directory.")
                print("Falling back to CLI mode...")
                interactive_mode()
        except ImportError as e:
            print(f"Error: Could not import GUI: {e}")
            print("Make sure PyQt5 is installed: pip install PyQt5")
            print("Falling back to CLI mode...")
            interactive_mode()
        return
    if args.command == 'cli':
        interactive_mode()
        return
    if args.command == 'info':
        show_info()
        return
    if args.command == 'download-models':
        download_models_command()
        return
    if not args.input:
        print("Error: Input path required")
        print("Usage: python klarity.py <command> <input> [-o output]")
        print("\nCommands: denoise, deblur, upscale, clean, full, frame-gen, clean-frame-gen, full-frame-gen")
        print("\nModel modes: -heavy (default), -lite")
        return
    if not auto_download_models_for_mode():
        print("Cannot continue without required models.")
        return
    force_cpu = args.cpu or args.device == 'cpu'
    device_type = None if args.device == 'auto' else args.device
    get_device(force_cpu=force_cpu, device_type=device_type)
    input_paths = get_files(args.input)
    if not input_paths:
        # get_files may return nothing for a single file path; accept it directly.
        if os.path.exists(args.input):
            input_paths = [args.input]
        else:
            print(f"Error: Input not found: {args.input}")
            return
    result = process_multiple_files(input_paths, args.output, args.command, args.multi, args.fps, args.upscale)
    if JSON_PROGRESS and result:
        # Final machine-readable status line for the GUI wrapper.
        output_file = result[0] if isinstance(result, list) else result
        json_output = json.dumps({
            'percent': 100,
            'step': 'Complete',
            'output': output_file
        })
        print(json_output)
        sys.stdout.flush()
# Script entry point: parse CLI arguments and dispatch (GUI by default).
if __name__ == '__main__':
    main()