Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import os | |
| from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration | |
| from huggingface_hub import hf_hub_download | |
| import torch.nn as nn | |
| class SpriteGenerator(nn.Module): | |
| def __init__(self, text_encoder_name="t5-base", latent_dim=512): | |
| super(SpriteGenerator, self).__init__() | |
| # Text encoder (T5 with lm_head) | |
| self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name) | |
| for param in self.text_encoder.parameters(): | |
| param.requires_grad = False | |
| # Proiezione dal testo al latent space | |
| self.text_projection = nn.Sequential( | |
| nn.Linear(768, latent_dim), | |
| nn.LeakyReLU(0.2), | |
| nn.Linear(latent_dim, latent_dim) | |
| ) | |
| # Generator | |
| self.generator = nn.Sequential( | |
| # Input: latent_dim x 1 x 1 -> 512 x 4 x 4 | |
| nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), | |
| nn.BatchNorm2d(512), | |
| nn.ReLU(True), | |
| # 512 x 4 x 4 -> 256 x 8 x 8 | |
| nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), | |
| nn.BatchNorm2d(256), | |
| nn.ReLU(True), | |
| # 256 x 8 x 8 -> 128 x 16 x 16 | |
| nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), | |
| nn.BatchNorm2d(128), | |
| nn.ReLU(True), | |
| # 128 x 16 x 16 -> 64 x 32 x 32 | |
| nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), | |
| nn.BatchNorm2d(64), | |
| nn.ReLU(True), | |
| # 64 x 32 x 32 -> 32 x 64 x 64 | |
| nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(True), | |
| # 32 x 64 x 64 -> 16 x 128 x 128 | |
| nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), | |
| nn.BatchNorm2d(16), | |
| nn.ReLU(True), | |
| # 16 x 128 x 128 -> 3 x 256 x 256 | |
| nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), | |
| ) | |
| # Frame interpolator | |
| self.frame_interpolator = nn.Sequential( | |
| nn.Linear(latent_dim + 1, latent_dim), | |
| nn.LeakyReLU(0.2), | |
| nn.Linear(latent_dim, latent_dim), | |
| nn.LeakyReLU(0.2) | |
| ) | |
| def forward(self, input_ids, attention_mask, num_frames=1): | |
| batch_size = input_ids.shape[0] | |
| # Encode text usando il T5 completo | |
| text_outputs = self.text_encoder.encoder( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| return_dict=True | |
| ) | |
| # Get text features | |
| text_features = text_outputs.last_hidden_state.mean(dim=1) | |
| # Project to latent space | |
| latent_vector = self.text_projection(text_features) | |
| # Generate multiple frames if needed | |
| all_frames = [] | |
| for frame_idx in range(max(num_frames.max().item(), 1)): | |
| frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1) | |
| # Combine latent vector with frame info | |
| frame_latent = self.frame_interpolator( | |
| torch.cat([latent_vector, frame_info], dim=1) | |
| ) | |
| # Generate frame | |
| frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3) | |
| frame = self.generator(frame_latent_reshaped) | |
| frame = torch.tanh(frame) | |
| all_frames.append(frame) | |
| # Stack all frames | |
| sprites = torch.stack(all_frames, dim=1) | |
| return sprites | |
| def initialize_model(): | |
| print("Inizializzazione del modello...") | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| model = SpriteGenerator() | |
| try: | |
| # Scarica il modello da Hugging Face Hub | |
| model_path = hf_hub_download( | |
| repo_id="Lod34/Animator2D-v2", | |
| filename="pytorch_model.bin", | |
| repo_type="model" | |
| ) | |
| # Carica il modello | |
| state_dict = torch.load(model_path, map_location=device) | |
| model.load_state_dict(state_dict) | |
| model = model.to(device) | |
| model.eval() | |
| print(f"Modello caricato con successo su {device}!") | |
| return model, device | |
| except Exception as e: | |
| print(f"Errore nel caricamento del modello: {str(e)}") | |
| raise | |
| def generate_sprite(prompt, num_frames=8): | |
| try: | |
| # Usa il modello e il device globali | |
| global model, device, tokenizer | |
| # Tokenizza il testo | |
| tokens = tokenizer(prompt, return_tensors="pt", padding=True) | |
| tokens = {k: v.to(device) for k, v in tokens.items()} | |
| # Genera l'immagine | |
| with torch.no_grad(): | |
| frames = model( | |
| input_ids=tokens["input_ids"], | |
| attention_mask=tokens["attention_mask"], | |
| num_frames=torch.tensor([num_frames], device=device) | |
| ) | |
| # Converte il tensore in immagine | |
| frames = (frames * 0.5 + 0.5).clamp(0, 1) | |
| frames = frames.cpu().numpy() | |
| # Ritorna il primo frame come esempio | |
| frame = frames[0, 0] # Prende il primo frame del batch | |
| frame = (frame * 255).astype('uint8').transpose(1, 2, 0) | |
| return Image.fromarray(frame) | |
| except Exception as e: | |
| print(f"Errore nella generazione: {str(e)}") | |
| raise | |
| # Inizializzazione globale | |
| print("Caricamento del modello e configurazione dell'interfaccia...") | |
| try: | |
| # Inizializzazione del modello e del tokenizer | |
| model, device = initialize_model() | |
| tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
| # Configurazione dell'interfaccia Gradio | |
| interface = gr.Interface( | |
| fn=generate_sprite, | |
| inputs=[ | |
| gr.Textbox( | |
| label="Descrivi lo sprite che vuoi generare", | |
| placeholder="Esempio: un personaggio pixel art che cammina" | |
| ), | |
| gr.Slider( | |
| minimum=1, | |
| maximum=16, | |
| value=8, | |
| step=1, | |
| label="Numero di frame", | |
| info="Più frame = animazione più fluida ma generazione più lenta" | |
| ) | |
| ], | |
| outputs=gr.Image(label="Sprite generato"), | |
| title="🎮 Animator2D-v2 Sprite Generator", | |
| description=""" | |
| ## Generatore di Sprite Animati | |
| Questo strumento genera sprite pixel art da descrizioni testuali. | |
| ### Come usare: | |
| 1. Inserisci una descrizione dello sprite che vuoi generare | |
| 2. Regola il numero di frame desiderati | |
| 3. Clicca su Submit e attendi la generazione | |
| ### Tips: | |
| - Sii specifico nella descrizione | |
| - Prova diversi numeri di frame per risultati diversi | |
| - Le descrizioni in inglese potrebbero funzionare meglio | |
| """, | |
| article=""" | |
| ### Note: | |
| - La generazione può richiedere alcuni secondi | |
| - Vengono mostrati solo i primi frame dell'animazione | |
| - Per risultati migliori, usa descrizioni dettagliate | |
| Creato da [Lod34](https://huggingface.co/Lod34) | |
| """ | |
| ) | |
| # Avvio dell'interfaccia | |
| interface.launch() | |
| except Exception as e: | |
| print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}") | |
| raise |