Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset, random_split | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from datasets import load_dataset | |
| from PIL import Image | |
| import numpy as np | |
| from torchvision import transforms | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| import io | |
| # Definiamo un percorso per salvare il modello addestrato | |
| MODEL_PATH = "sprite_generator_model" | |
| os.makedirs(MODEL_PATH, exist_ok=True) | |
| # Carichiamo il dataset da Hugging Face | |
| print("Caricamento del dataset...") | |
| dataset = load_dataset("pawkanarek/spraix_1024") | |
| print(f"Dataset caricato. Dimensioni: {len(dataset['train'])} esempi di training") | |
| # Verifichiamo gli split disponibili | |
| print("Split disponibili nel dataset:") | |
| print(dataset.keys()) | |
| # Stampiamo un esempio per capire la struttura del dataset | |
| print("Esempio di dato dal dataset:") | |
| example = dataset['train'][0] | |
| print("Chiavi disponibili:", example.keys()) | |
| for key in example: | |
| print(f"{key}: {type(example[key])}") | |
| # Se il valore è un dizionario, stampiamo anche le sue chiavi | |
| if isinstance(example[key], dict): | |
| print(f" Sottochavi: {example[key].keys()}") | |
| # Classe per il nostro dataset personalizzato | |
| class SpriteDataset(Dataset): | |
| def __init__(self, dataset_to_use, max_length=128): | |
| self.dataset = dataset_to_use | |
| self.tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
| self.max_length = max_length | |
| self.transform = transforms.Compose([ | |
| transforms.Resize((256, 256)), | |
| transforms.ToTensor(), | |
| transforms.ConvertImageDtype(torch.float), # Converti in float32 | |
| transforms.Lambda(lambda image: image[:3, :, :]), # Seleziona solo i primi 3 canali (RGB) | |
| transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | |
| ]) | |
| def __len__(self): | |
| return len(self.dataset) | |
| def __getitem__(self, idx): | |
| item = self.dataset[idx] | |
| # Estrai informazioni dalla descrizione completa | |
| description = item['text'] if 'text' in item else "" | |
| # Estrai numero di frame dal testo | |
| num_frames = 1 # valore di default | |
| if "frame" in description: | |
| # Cerca numeri seguiti da "frame" nel testo | |
| import re | |
| frames_match = re.search(r'(\d+)-frame', description) | |
| if frames_match: | |
| num_frames = int(frames_match.group(1)) | |
| # Prepara il testo per il modello | |
| text_input = f""" | |
| Description: {description} | |
| Number of frames: {num_frames} | |
| """ | |
| # Tokenizziamo l'input testuale | |
| encoded_text = self.tokenizer( | |
| text_input, | |
| padding="max_length", | |
| max_length=self.max_length, | |
| truncation=True, | |
| return_tensors="pt" | |
| ) | |
| # Prepariamo l'immagine (o le immagini se ci sono frame multipli) | |
| sprite_frames = [] | |
| # Controlla le chiavi disponibili per i frame | |
| if 'image' in item: | |
| # Se c'è un'unica immagine | |
| img = item['image'] | |
| if isinstance(img, dict) and 'bytes' in img: | |
| img_pil = Image.open(io.BytesIO(img['bytes'])) | |
| sprite_frames.append(self.transform(img_pil)) | |
| elif hasattr(img, 'convert'): # Se è già un'immagine PIL | |
| sprite_frames.append(self.transform(img)) | |
| else: | |
| # Prova a cercare frame_0, frame_1, ecc. | |
| for frame in range(num_frames): | |
| frame_key = f'frame_{frame}' | |
| if frame_key in item: | |
| img = item[frame_key] | |
| if isinstance(img, dict) and 'bytes' in img: | |
| img_pil = Image.open(io.BytesIO(img['bytes'])) | |
| sprite_frames.append(self.transform(img_pil)) | |
| elif hasattr(img, 'convert'): # Se è già un'immagine PIL | |
| sprite_frames.append(self.transform(img)) | |
| # Se non abbiamo trovato immagini, prova a cercare altre chiavi comuni | |
| if not sprite_frames: | |
| possible_image_keys = ['image', 'img', 'sprite', 'frames'] | |
| for key in possible_image_keys: | |
| if key in item and item[key] is not None: | |
| img = item[key] | |
| if isinstance(img, dict) and 'bytes' in img: | |
| img_pil = Image.open(io.BytesIO(img['bytes'])) | |
| sprite_frames.append(self.transform(img_pil)) | |
| elif hasattr(img, 'convert'): # Se è già un'immagine PIL | |
| sprite_frames.append(self.transform(img)) | |
| break | |
| # Se ancora non abbiamo frame, crea un tensore vuoto | |
| if not sprite_frames: | |
| sprite_frames.append(torch.zeros((3, 256, 256))) | |
| # Combiniamo tutti i frame in un unico tensore | |
| sprite_tensor = torch.stack(sprite_frames) | |
| return { | |
| "input_ids": encoded_text.input_ids.squeeze(), | |
| "attention_mask": encoded_text.attention_mask.squeeze(), | |
| "sprite_frames": sprite_tensor, | |
| "num_frames": torch.tensor(num_frames, dtype=torch.int64) | |
| } | |
| # Modello generatore di sprite | |
| class SpriteGenerator(nn.Module): | |
| def __init__(self, text_encoder_name="t5-base", latent_dim=512): | |
| super(SpriteGenerator, self).__init__() | |
| # Encoder testuale | |
| self.text_encoder = AutoModelForSeq2SeqLM.from_pretrained(text_encoder_name) | |
| # Freeziamo i parametri dell'encoder per iniziare | |
| for param in self.text_encoder.parameters(): | |
| param.requires_grad = False | |
| # Proiezione dal testo al latent space | |
| self.text_projection = nn.Sequential( | |
| nn.Linear(self.text_encoder.config.d_model, latent_dim), | |
| nn.LeakyReLU(0.2), | |
| nn.Linear(latent_dim, latent_dim) | |
| ) | |
| # Frame generator (una rete deconvoluzionale) | |
| self.generator = nn.Sequential( | |
| # Input: latent_dim x 1 x 1 | |
| nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), # -> 512 x 4 x 4 | |
| nn.BatchNorm2d(512), | |
| nn.ReLU(True), | |
| nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # -> 256 x 8 x 8 | |
| nn.BatchNorm2d(256), | |
| nn.ReLU(True), | |
| nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # -> 128 x 16 x 16 | |
| nn.BatchNorm2d(128), | |
| nn.ReLU(True), | |
| nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), # -> 64 x 32 x 32 | |
| nn.BatchNorm2d(64), | |
| nn.ReLU(True), | |
| nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), # -> 32 x 64 x 64 | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(True), | |
| nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), # -> 16 x 128 x 128 | |
| nn.BatchNorm2d(16), | |
| nn.ReLU(True), | |
| nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), # -> 3 x 256 x 256 | |
| nn.Tanh() | |
| ) | |
| # Frame interpolator per supportare animazioni con più frame | |
| self.frame_interpolator = nn.Sequential( | |
| nn.Linear(latent_dim + 1, latent_dim), # +1 per l'informazione sul frame | |
| nn.LeakyReLU(0.2), | |
| nn.Linear(latent_dim, latent_dim), | |
| nn.LeakyReLU(0.2) | |
| ) | |
| def forward(self, input_ids, attention_mask, num_frames=1): | |
| batch_size = input_ids.shape[0] | |
| # Codifichiamo il testo | |
| text_outputs = self.text_encoder.encoder( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| return_dict=True | |
| ) | |
| # Utilizziamo l'ultimo hidden state | |
| text_features = text_outputs.last_hidden_state.mean(dim=1) # Media per ottenere un vettore per esempio | |
| # Proiettiamo nello spazio latente | |
| latent_vector = self.text_projection(text_features) | |
| # Generiamo frame multipli se necessario | |
| all_frames = [] | |
| for frame_idx in range(max(num_frames.max().item(), 1)): | |
| # Normalizziamo l'indice del frame | |
| frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1) | |
| # Combiniamo il vettore latente con l'informazione sul frame | |
| frame_latent = self.frame_interpolator( | |
| torch.cat([latent_vector, frame_info], dim=1) | |
| ) | |
| # Ricordiamo quanti frame generare per ogni esempio del batch | |
| frame_mask = (frame_idx < num_frames).float().unsqueeze(1) | |
| # Riformattiamo per il generatore | |
| frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3) # [B, latent_dim, 1, 1] | |
| # Generiamo il frame | |
| frame = self.generator(frame_latent_reshaped) * frame_mask.unsqueeze(2).unsqueeze(3) | |
| all_frames.append(frame) | |
| # Combiniamo tutti i frame | |
| sprites = torch.stack(all_frames, dim=1) # [B, num_frames, 3, 256, 256] | |
| return sprites | |
| # Funzione per addestrare il modello | |
| def train_model(model, train_loader, val_loader, epochs=10, lr=0.0002): | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Utilizzo del dispositivo: {device}") | |
| model = model.to(device) | |
| optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999)) | |
| criterion = nn.MSELoss() | |
| best_val_loss = float('inf') | |
| for epoch in range(epochs): | |
| # Training | |
| model.train() | |
| train_loss = 0.0 | |
| for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"): | |
| input_ids = batch["input_ids"].to(device) | |
| attention_mask = batch["attention_mask"].to(device) | |
| target_sprites = batch["sprite_frames"].to(device) | |
| num_frames = batch["num_frames"].to(device) | |
| optimizer.zero_grad() | |
| # Forward pass | |
| output_sprites = model(input_ids, attention_mask, num_frames) | |
| # Calcoliamo la loss per il batch | |
| loss = 0.0 | |
| for i in range(len(num_frames)): | |
| # Utilizziamo solo i frame validi per ogni esempio | |
| valid_frames = min(output_sprites.shape[1], target_sprites.shape[1], num_frames[i].item()) | |
| if valid_frames > 0: | |
| loss += criterion( | |
| output_sprites[i, :valid_frames], | |
| target_sprites[i, :valid_frames] | |
| ) | |
| loss = loss / len(num_frames) # Media per batch | |
| # Backward pass | |
| loss.backward() | |
| optimizer.step() | |
| train_loss += loss.item() | |
| train_loss /= len(train_loader) | |
| # Validation | |
| model.eval() | |
| val_loss = 0.0 | |
| with torch.no_grad(): | |
| for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation"): | |
| input_ids = batch["input_ids"].to(device) | |
| attention_mask = batch["attention_mask"].to(device) | |
| target_sprites = batch["sprite_frames"].to(device) | |
| num_frames = batch["num_frames"].to(device) | |
| output_sprites = model(input_ids, attention_mask, num_frames) | |
| # Calcoliamo la loss per il batch di validazione | |
| loss = 0.0 | |
| for i in range(len(num_frames)): | |
| valid_frames = min(output_sprites.shape[1], target_sprites.shape[1], num_frames[i].item()) | |
| if valid_frames > 0: | |
| loss += criterion( | |
| output_sprites[i, :valid_frames], | |
| target_sprites[i, :valid_frames] | |
| ) | |
| loss = loss / len(num_frames) | |
| val_loss += loss.item() | |
| val_loss /= len(val_loader) | |
| print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}") | |
| # Salviamo il modello migliore | |
| if val_loss < best_val_loss: | |
| best_val_loss = val_loss | |
| torch.save(model.state_dict(), os.path.join(MODEL_PATH, "best_model.pth")) | |
| print(f"Modello salvato con Val Loss: {val_loss:.4f}") | |
| # Salviamo il modello finale | |
| torch.save(model.state_dict(), os.path.join(MODEL_PATH, "Animator2D-v2.pth")) | |
| print(f"Addestramento completato. Modello finale salvato.") | |
| return model | |
| # Codice per l'esecuzione dell'addestramento | |
| if __name__ == "__main__": | |
| # Dividiamo il dataset in train e validation manualmente | |
| # dato che abbiamo solo lo split "train" | |
| train_size = int(0.8 * len(dataset['train'])) # 80% per training | |
| val_size = len(dataset['train']) - train_size # 20% per validation | |
| print(f"Dividendo il dataset: {train_size} esempi per training, {val_size} esempi per validation") | |
| # Creiamo i subset | |
| train_subset, val_subset = random_split( | |
| dataset['train'], | |
| [train_size, val_size] | |
| ) | |
| # Creiamo i dataset personalizzati | |
| train_dataset = SpriteDataset(train_subset) | |
| val_dataset = SpriteDataset(val_subset) | |
| print(f"Dataset creati: {len(train_dataset)} esempi di training, {len(val_dataset)} esempi di validation") | |
| # Creiamo i dataloader | |
| train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4) | |
| val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4) | |
| # Creiamo e addestriamo il modello | |
| model = SpriteGenerator() | |
| trained_model = train_model( | |
| model, | |
| train_loader, | |
| val_loader, | |
| epochs=20 | |
| ) | |
| print("Modello addestrato con successo!") |