unionpoint's picture
Upload folder using huggingface_hub
5d2fa0b verified
import math
import os
import random
import shutil
from pathlib import Path
import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
class EarlyStopping:
    """Track a monitored metric and flag when training should stop.

    Call the instance once per evaluation with the latest metric value. The
    call returns ``True`` when the value is at least as good as the best one
    seen so far and ``False`` otherwise. Once ``patience`` consecutive calls
    have failed to improve, ``early_stop`` is set to ``True``.

    Args:
        patience: Number of consecutive non-improving evaluations tolerated
            before ``early_stop`` is set.
        mode: ``"min"`` when lower metric values are better; any other value
            is treated as "higher is better".
    """

    def __init__(self, patience=7, mode="max"):
        self.patience = patience
        self.mode = mode
        self.counter = 0  # consecutive evaluations without improvement
        self.best_score = None  # best sign-normalised score seen so far
        self.early_stop = False

    def __call__(self, metric_value):
        """Record ``metric_value`` and report whether it is a new best.

        Args:
            metric_value: Latest value of the monitored metric.

        Returns:
            ``True`` if the value ties or beats the best score so far (the
            patience counter is reset), ``False`` otherwise (the counter is
            incremented and ``early_stop`` may be set).
        """
        # Flip the sign in "min" mode so that larger always means better.
        score = -metric_value if self.mode == "min" else metric_value

        # NaN compares False against everything; without this guard a NaN
        # metric would be stored as the best score and block early stopping
        # for the rest of the run. Treat it as a non-improvement instead.
        improved = not math.isnan(score) and (
            self.best_score is None or score >= self.best_score
        )
        if improved:
            self.best_score = score
            self.counter = 0
            return True

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False
class CosineAnnealingWarmupLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing towards ``min_lr``.

    For the first ``warmup_steps`` steps the learning rate grows linearly
    from 0 to each group's base learning rate. Afterwards it follows a
    half-cosine from the base learning rate down to ``min_lr``, reaching the
    floor at ``total_steps`` and staying there.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the cosine phase reaches ``min_lr``.
        min_lr: Learning-rate floor shared by all parameter groups.
        last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        # Express the floor as a fraction of each group's *base* learning
        # rate. The base class derives ``base_lrs`` from ``initial_lr`` when
        # that key exists (resumed runs, previously scheduled optimizers), so
        # the ratio must use the same value; the current ``lr`` may already
        # have been decayed. The ratios are computed before the base
        # constructor runs because it may already call ``get_lr``.
        self.min_lr_ratios = [
            min_lr / max(group.get("initial_lr", group["lr"]), 1e-12)
            for group in optimizer.param_groups
        ]
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for this step."""
        curr_step = self.last_epoch

        # Linear warmup phase: scale from 0 up to the base learning rate.
        if curr_step < self.warmup_steps:
            scale = curr_step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]

        # Cosine annealing phase: progress is clamped to [0, 1] so the rate
        # stays at the floor once ``total_steps`` has been passed.
        progress = (curr_step - self.warmup_steps) / max(
            1, self.total_steps - self.warmup_steps
        )
        progress = min(1.0, max(0.0, progress))
        cosine = 0.5 * (1 + math.cos(math.pi * progress))
        return [
            base_lr * (ratio + (1 - ratio) * cosine)
            for base_lr, ratio in zip(self.base_lrs, self.min_lr_ratios)
        ]
def set_seed(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every seeding function.
        deterministic: When true, also set ``cudnn.deterministic`` to
            ``True`` and ``cudnn.benchmark`` to ``False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def load_config(config_path):
    """Load the configuration at ``config_path`` using ``OmegaConf.load``.

    Args:
        config_path: Location of the configuration file, passed through
            unchanged.

    Returns:
        Whatever ``OmegaConf.load`` returns for that path.
    """
    config = OmegaConf.load(config_path)
    return config
def save_checkpoint(state, is_best, checkpoint_dir, filename="last.pt"):
    """Save a per-epoch checkpoint and refresh the "latest"/"best" copies.

    The state is written to ``checkpoint_epoch_{epoch}.pt`` inside
    ``checkpoint_dir`` and then copied to ``filename`` (the rolling latest
    checkpoint). When ``is_best`` is true it is also copied to ``best.pt``.

    Args:
        state: Checkpoint object passed to ``torch.save``; must support
            ``state["epoch"]``.
        is_best: Whether this checkpoint should also be stored as ``best.pt``.
        checkpoint_dir: Output directory; created if it does not exist.
        filename: File name of the rolling latest checkpoint inside
            ``checkpoint_dir``.

    Raises:
        KeyError: If ``state`` has no ``"epoch"`` entry.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    epoch = state["epoch"]
    epoch_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
    torch.save(state, epoch_path)

    # The ``filename`` argument used to be overwritten and ignored; it now
    # names the rolling copy. The default keeps the previous "last.pt".
    last_path = os.path.join(checkpoint_dir, filename)
    shutil.copyfile(epoch_path, last_path)

    if is_best:
        best_path = os.path.join(checkpoint_dir, "best.pt")
        shutil.copyfile(epoch_path, best_path)
def check_dataset(data_dir):
    """Recursively verify the image files under ``data_dir``.

    Every file whose suffix (case-insensitive) is ``.jpg``, ``.jpeg``,
    ``.png`` or ``.webp`` is opened with ``Image.open`` and checked with
    ``verify()``. Failures are printed and collected.

    Args:
        data_dir: Root directory to scan.

    Returns:
        A list of ``Path`` objects for the files that failed verification;
        empty when no failure was found.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    data_path = Path(data_dir)
    corrupt_files = []
    print(f"Checking images in {data_dir}...")

    for img_path in data_path.glob("**/*"):
        if img_path.suffix.lower() not in image_suffixes:
            continue
        # A directory named like an image is not a corrupt image; skip it.
        if img_path.is_dir():
            continue
        try:
            with Image.open(img_path) as img:
                img.verify()
        except Exception as e:
            # Deliberately broad: any failure to open or verify a file marks
            # it as corrupt and the scan continues with the next file.
            print(f"CORRUPT: {img_path} | Error: {e}")
            corrupt_files.append(img_path)

    if corrupt_files:
        print(f"\nFound {len(corrupt_files)} corrupted files.")
    else:
        print("Dataset is clean")

    # Return the offenders so callers can act on them, not just read the log.
    return corrupt_files