Lorenzo Adacher committed on
Commit
902125a
·
verified ·
1 Parent(s): 2274519

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -2,17 +2,16 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  import os
5
- from transformers import AutoTokenizer, AutoModel
6
  from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
9
- # Definizione del modello
10
  class SpriteGenerator(nn.Module):
11
  def __init__(self, text_encoder_name="t5-base", latent_dim=512):
12
  super(SpriteGenerator, self).__init__()
13
 
14
- # Text encoder (T5)
15
- self.text_encoder = AutoModel.from_pretrained(text_encoder_name)
16
  for param in self.text_encoder.parameters():
17
  param.requires_grad = False
18
 
@@ -23,30 +22,39 @@ class SpriteGenerator(nn.Module):
23
  nn.Linear(latent_dim, latent_dim)
24
  )
25
 
26
- # Generator
27
  self.generator = nn.Sequential(
28
- nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
 
29
  nn.BatchNorm2d(512),
30
- nn.LeakyReLU(0.2, inplace=True),
31
-
32
- nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
33
  nn.BatchNorm2d(256),
34
- nn.LeakyReLU(0.2, inplace=True),
35
-
36
- nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
37
  nn.BatchNorm2d(128),
38
- nn.LeakyReLU(0.2, inplace=True),
39
-
40
- nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
41
  nn.BatchNorm2d(64),
42
- nn.LeakyReLU(0.2, inplace=True),
43
-
44
- nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
45
  nn.BatchNorm2d(32),
46
- nn.LeakyReLU(0.2, inplace=True),
 
 
 
 
 
 
 
 
 
47
 
48
- nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
49
- nn.Tanh()
50
  )
51
 
52
  # Frame interpolator
@@ -60,8 +68,8 @@ class SpriteGenerator(nn.Module):
60
  def forward(self, input_ids, attention_mask, num_frames=1):
61
  batch_size = input_ids.shape[0]
62
 
63
- # Encode text
64
- text_outputs = self.text_encoder(
65
  input_ids=input_ids,
66
  attention_mask=attention_mask,
67
  return_dict=True
@@ -92,7 +100,7 @@ class SpriteGenerator(nn.Module):
92
  sprites = torch.stack(all_frames, dim=1)
93
 
94
  return sprites
95
-
96
  # Costanti
97
  MODEL_ID = "Lod34/Animator2D-v2"
98
  CACHE_DIR = "model_cache"
@@ -242,5 +250,3 @@ def create_interface():
242
 
243
  # Crea l'interfaccia
244
  demo = create_interface()
245
-
246
- # Per Spaces, non usare demo.launch()
 
2
  import torch
3
  from PIL import Image
4
  import os
5
+ from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration
6
  from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
 
9
  class SpriteGenerator(nn.Module):
10
  def __init__(self, text_encoder_name="t5-base", latent_dim=512):
11
  super(SpriteGenerator, self).__init__()
12
 
13
+ # Text encoder (T5 with lm_head)
14
+ self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name)
15
  for param in self.text_encoder.parameters():
16
  param.requires_grad = False
17
 
 
22
  nn.Linear(latent_dim, latent_dim)
23
  )
24
 
25
+ # Generator modificato per corrispondere ai pesi salvati
26
  self.generator = nn.Sequential(
27
+ # Input: latent_dim x 1 x 1
28
+ nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), # -> 512 x 4 x 4
29
  nn.BatchNorm2d(512),
30
+ nn.ReLU(True),
31
+
32
+ nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # -> 256 x 8 x 8
33
  nn.BatchNorm2d(256),
34
+ nn.ReLU(True),
35
+
36
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # -> 128 x 16 x 16
37
  nn.BatchNorm2d(128),
38
+ nn.ReLU(True),
39
+
40
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), # -> 64 x 32 x 32
41
  nn.BatchNorm2d(64),
42
+ nn.ReLU(True),
43
+
44
+ nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), # -> 32 x 64 x 64
45
  nn.BatchNorm2d(32),
46
+ nn.ReLU(True),
47
+
48
+ nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), # -> 16 x 128 x 128
49
+ nn.BatchNorm2d(16),
50
+ nn.ReLU(True),
51
+
52
+ # Layer finale modificato per corrispondere ai pesi
53
+ nn.ConvTranspose2d(16, 16, 4, 2, 1, bias=False), # -> 16 x 256 x 256
54
+ nn.BatchNorm2d(16),
55
+ nn.ReLU(True),
56
 
57
+ nn.Conv2d(16, 3, 3, 1, 1) # Layer di output per RGB
 
58
  )
59
 
60
  # Frame interpolator
 
68
  def forward(self, input_ids, attention_mask, num_frames=1):
69
  batch_size = input_ids.shape[0]
70
 
71
+ # Encode text usando il T5 completo
72
+ text_outputs = self.text_encoder.encoder(
73
  input_ids=input_ids,
74
  attention_mask=attention_mask,
75
  return_dict=True
 
100
  sprites = torch.stack(all_frames, dim=1)
101
 
102
  return sprites
103
+
104
  # Costanti
105
  MODEL_ID = "Lod34/Animator2D-v2"
106
  CACHE_DIR = "model_cache"
 
250
 
251
  # Crea l'interfaccia
252
  demo = create_interface()