Spaces:
No application file
No application file
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import AutoTokenizer, AutoModel | |
| from datasets import load_dataset | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
class SpriteDataset(Dataset):
    """Caption/sprite pairs from the ``pawkanarek/spraix_1024`` HF dataset.

    Each sample is a dict with:
      * ``'text_ids'`` / ``'text_mask'``: BERT (``bert-base-uncased``) token
        ids and attention mask of the caption, padded/truncated to 128.
      * ``'image'``: the sprite resized to 64x64, converted to a tensor and
        normalized from [0, 1] to [-1, 1].
    """

    def __init__(self, dataset_split="train"):
        # Hub download/caching is handled by `datasets`; `dataset_split`
        # selects which split is loaded.
        self.dataset = load_dataset("pawkanarek/spraix_1024", split=dataset_split)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        pipeline = [
            transforms.Resize((64, 64)),  # every sprite gets the same size
            transforms.ToTensor(),
            # A single mean/std pair broadcasts over all three RGB channels.
            transforms.Normalize((0.5,), (0.5,)),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        # The caption text (frame count, description, action, direction).
        caption = str(record['text'])
        tokens = self.tokenizer(
            caption,
            padding="max_length",
            max_length=128,
            truncation=True,
            return_tensors="pt",
        )

        # `record['image']` is already a PIL image; force 3 channels.
        # NOTE(review): for RGBA sprites, convert('RGB') discards alpha without
        # compositing, so transparent pixels keep their stored RGB — confirm.
        pixels = self.transform(record['image'].convert('RGB'))

        # return_tensors="pt" yields shape (1, 128); drop the leading axis.
        return {
            'text_ids': tokens['input_ids'].squeeze(),
            'text_mask': tokens['attention_mask'].squeeze(),
            'image': pixels,
        }
class TextEncoder(nn.Module):
    """Encode tokenized text into one fixed-size feature vector per sample.

    Runs a pretrained HuggingFace encoder and projects the last-layer hidden
    state of the first token ([CLS] for BERT) to ``out_dim`` features.

    Args:
        model_name: checkpoint passed to ``AutoModel.from_pretrained``.
            Defaults to ``"bert-base-uncased"``, the checkpoint whose
            tokenizer the rest of this file uses.
        out_dim: width of the returned features. It must equal the
            generator's ``latent_dim`` (512 by default).
    """

    def __init__(self, model_name="bert-base-uncased", out_dim=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Take the encoder width from its config instead of hard-coding 768,
        # so swapping the checkpoint cannot cause a shape mismatch.
        self.linear = nn.Linear(self.bert.config.hidden_size, out_dim)

    def forward(self, input_ids, attention_mask):
        """Return ``(batch, out_dim)`` features for a batch of token ids.

        Args:
            input_ids: token ids, presumably ``(batch, seq_len)``.
            attention_mask: mask with the same shape as ``input_ids``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's ([CLS]) hidden state as the sentence summary.
        return self.linear(outputs.last_hidden_state[:, 0, :])
class SpriteGenerator(nn.Module):
    """DCGAN-style decoder mapping a latent vector to a 3x64x64 image.

    The first transposed convolution lifts the ``latent_dim`` x 1 x 1 input
    to 512 x 4 x 4. Each later stage doubles the spatial size and halves the
    channels: 4 -> 8 -> 16 -> 32 -> 64 pixels, 512 -> 256 -> 128 -> 64 -> 3
    channels. Tanh bounds the output to [-1, 1], matching the dataset's
    normalization.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        stack = [
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        # Upsampling stages (kernel 4, stride 2, padding 1 doubles H and W).
        for ch_in, ch_out in ((512, 256), (256, 128), (128, 64)):
            stack.extend([
                nn.ConvTranspose2d(ch_in, ch_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        stack.extend([
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.generator = nn.Sequential(*stack)

    def forward(self, z):
        # Append two singleton spatial axes: (..., latent_dim) -> (..., latent_dim, 1, 1).
        return self.generator(z[..., None, None])
class Animator2D(nn.Module):
    """End-to-end text-to-sprite model.

    A ``TextEncoder`` turns token ids into a feature vector. That vector is
    fed directly to a ``SpriteGenerator`` as its latent code, which yields a
    3x64x64 image in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Registration order (encoder first) fixes the state_dict key order.
        self.text_encoder = TextEncoder()
        self.sprite_generator = SpriteGenerator()

    def forward(self, input_ids, attention_mask):
        latent = self.text_encoder(input_ids, attention_mask)
        return self.sprite_generator(latent)
def train_model(num_epochs=100, batch_size=32, learning_rate=0.0002,
                log_interval=100, save_path="animator2d_model.pth"):
    """Train an Animator2D on the sprite dataset and save its weights.

    Each sprite image is regressed directly from its caption with a
    pixel-wise MSE loss. The text encoder and the generator are optimized
    jointly by a single Adam optimizer.

    Args:
        num_epochs: number of passes over the training split.
        batch_size: samples per optimization step.
        learning_rate: Adam learning rate, shared by all parameters.
        log_interval: print the loss every ``log_interval`` batches. A falsy
            value (0 or None) disables logging.
        save_path: file that ``model.state_dict()`` is written to after the
            final epoch.

    Returns:
        The trained Animator2D, still on the training device and in train mode.
    """
    dataset = SpriteDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(dataloader)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Animator2D().to(device)
    # Be explicit: the generator's BatchNorm layers behave differently in train mode.
    model.train()
    # NOTE(review): lr=2e-4 with betas=(0.5, 0.999) are DCGAN-style settings,
    # and they also apply to the pretrained BERT weights. Consider separate
    # parameter groups with a smaller rate for the text encoder.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(dataloader):
            text_ids = batch['text_ids'].to(device)
            text_mask = batch['text_mask'].to(device)
            real_images = batch['image'].to(device)

            generated_images = model(text_ids, text_mask)
            loss = criterion(generated_images, real_images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_interval and batch_idx % log_interval == 0:
                # 1-based counters, so the last epoch reads N/N rather than (N-1)/N.
                print(f"Epoch [{epoch + 1}/{num_epochs}] Batch [{batch_idx + 1}/{num_batches}] Loss: {loss.item():.4f}")

    torch.save(model.state_dict(), save_path)
    return model
def generate_sprite_animation(model, num_frames, description, action, direction):
    """Render one image from a text prompt with a trained model.

    The prompt is formatted as
    ``"<num_frames>-frame sprite animation of: <description>, that: <action>, facing: <direction>"``,
    tokenized with the ``bert-base-uncased`` tokenizer (padded/truncated to
    128 tokens) and passed through ``model``.

    Args:
        model: an Animator2D, or any module with the same
            ``forward(input_ids, attention_mask)`` signature, presumably
            returning ``(1, 3, H, W)`` images in [-1, 1].
        num_frames, description, action, direction: values interpolated
            into the prompt.

    Returns:
        A PIL image of the output (64x64 RGB with the default generator).

    Side effects:
        Puts ``model`` into eval mode and leaves it there.
    """
    # Feed inputs to the device the model's weights are on. Choosing
    # "cuda if available" independently breaks for a CPU-resident model on a
    # CUDA machine (device-mismatch error inside forward).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    text = f"{num_frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoded_text = tokenizer(
        text,
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        text_ids = encoded_text['input_ids'].to(device)
        text_mask = encoded_text['attention_mask'].to(device)
        generated_sprite = model(text_ids, text_mask)

    # Drop the batch axis, then map [-1, 1] -> [0, 1]. Clamp guards against
    # float overshoot or a model without a final Tanh.
    generated_sprite = generated_sprite.cpu().squeeze(0)
    generated_sprite = ((generated_sprite + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated_sprite)
# Example usage: train from scratch, then render one demo sprite sheet.
if __name__ == "__main__":
    model = train_model()

    # Keys mirror generate_sprite_animation's parameter names, so the dict
    # can be splatted straight into the call.
    test_params = dict(
        num_frames=17,
        description="red-haired hobbit in green cape",
        action="shoots with slingshot",
        direction="East",
    )
    sprite_sheet = generate_sprite_animation(model, **test_params)
    sprite_sheet.save("generated_sprite.png")