Spaces:
Sleeping
Sleeping
| # Pip-Packages ----------------------------------------------------- | |
| import importlib | |
| import os | |
| import sys | |
| from datetime import datetime | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch import optim | |
| from torch.utils.data import DataLoader | |
| # From local package ----------------------------------------------- | |
| from disvae.models.losses import get_loss_f | |
| from disvae.models.vae import init_specific_model | |
| from disvae.training import Trainer | |
| from disvae.utils.modelIO import save_model | |
| # Loss stuff: | |
def parse_losses(p_model, filename="train_losses.log"):
    """Read a training-loss log and pivot it into one column per loss name.

    The log at ``p_model/filename`` must have ``Loss`` and ``Value`` columns;
    the result has one column per distinct loss, rows in log order.
    """
    raw = pd.read_csv(Path(p_model) / filename)
    loss_names = raw["Loss"].unique()
    columns = [np.array(raw[raw["Loss"] == name]["Value"]) for name in loss_names]
    return pd.DataFrame(np.array(columns).T, columns=loss_names)
def get_kl_loss_latent(df):
    """Map latent-dimension index -> final KL value, sorted by value descending.

    *df* must already be parsed (one column per loss, see ``parse_losses``);
    only columns containing ``kl_loss_`` are considered, and the trailing
    ``_<n>`` of each column name is taken as the latent index.
    """
    per_dim = {}
    for col in df:
        if "kl_loss_" in col:
            per_dim[int(col.split("_")[-1])] = df[col].iloc[-1]
    ordered = sorted(per_dim.items(), key=lambda kv: kv[1], reverse=True)
    return dict(ordered)
def get_kl_dict(p_model):
    """Parse the loss log under *p_model* and return its per-latent KL dict."""
    return get_kl_loss_latent(parse_losses(p_model))
# Dataloader convenience stuff
# def get_dataloader(dataset: torch.data.Dataset, batch_size, num_workers):
#     # This function is fairly involved; it is quicker to do in the notebook.
#     # These pieces are also needed for visualizing the dataset.
#     # p_dataset_module, dataset_class, dataset_args
#     # Import module
#     # if p_dataset_module not in sys.path:
#     #     sys.path.append(str(Path(p_dataset_module).parent))
#     # Dataset = getattr(
#     #     importlib.import_module(Path(p_dataset_module).stem), dataset_class
#     # )
#     # # From here on, as if it had been imported normally
#     # ds = Dataset(**dataset_args)
#
#     return loader
def get_export_dir(base_dir: str, folder_name=None):
    """Create and return a fresh export directory below *base_dir*.

    Parameters
    ----------
    base_dir : str
        Parent directory for model exports (created if missing).
    folder_name : str, optional
        Name of the export folder.  When ``None`` a timestamped name such
        as ``Model_2024-01-01T12-00-00`` is generated.

    Returns
    -------
    Path
        The newly created directory.

    Raises
    ------
    ValueError
        If the target directory already exists.
    """
    if folder_name is None:
        stamp = datetime.now().replace(microsecond=0).isoformat()
        folder_name = "Model_" + stamp.replace(" ", "_").replace(":", "-")
    rtn = Path(base_dir) / folder_name
    try:
        # Atomic create-or-fail; the previous exists()/makedirs() sequence
        # had a check-then-create race between the test and the creation.
        rtn.mkdir(parents=True, exist_ok=False)
    except FileExistsError:
        raise ValueError("Output directory already exists.") from None
    return rtn
def train_model(model, data_loader, loss_f, device, lr, epochs, export_dir):
    """Run the disvae ``Trainer`` on *data_loader* and persist the result.

    An Adam optimizer with learning rate *lr* is built for *model*; the
    trainer checkpoints into *export_dir* every 10 epochs, and the final
    model is saved there once training finishes.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    trainer = Trainer(
        model,
        optimizer,
        loss_f,
        device=device,
        save_dir=export_dir,
        is_progress_bar=True,
    )
    trainer(data_loader, epochs=epochs, checkpoint_every=10)
    # Checkpoints are already written during training; this is the final save.
    save_model(trainer.model, export_dir)
def train(dataset, config) -> Path:
    """Train a disentangled VAE on *dataset* as described by *config*.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Training data; the shape of item ``[0][0]`` is used as the image size.
    config : dict
        Must contain the keys ``data_params`` (``batch_size``,
        ``num_workers``), ``model_params``, ``loss_params``,
        ``export_params`` and ``trainer_params``.

    Returns
    -------
    Path
        The export directory containing the trained model.
    """
    print("1) Set device")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device:\t\t {device}")
    print("2) Get dataloader")
    dataloader = DataLoader(
        dataset,
        batch_size=config["data_params"]["batch_size"],
        shuffle=True,
        # BUG FIX: the function object was passed instead of being called,
        # which is always truthy and silently enabled pin_memory even on
        # CPU-only machines.
        pin_memory=torch.cuda.is_available(),
        num_workers=config["data_params"]["num_workers"],
    )
    print("3) Build model")
    img_size = list(dataloader.dataset[0][0].shape)
    print(f"Image size: \t {img_size}")
    model = init_specific_model(img_size=img_size, **config["model_params"])
    model = model.to(device)  # make sure trainer and viz are on the same device
    print("4) Build loss function")
    loss_f = get_loss_f(
        n_data=len(dataloader.dataset), device=device, **config["loss_params"]
    )
    print("5) Parse Export Params")
    export_dir = get_export_dir(**config["export_params"])
    print("6) Training model")
    train_model(
        model=model,
        data_loader=dataloader,
        loss_f=loss_f,
        device=device,
        export_dir=export_dir,
        **config["trainer_params"],
    )
    return export_dir