Spaces:
Paused
Paused
Daniel Gil-U Fuhge committed on
Commit ·
f7da327
1
Parent(s): 91b5220
update to new temperature approach
Browse files
AnimationTransformer.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import math
|
| 2 |
import time
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
|
@@ -170,33 +171,39 @@ def fit(model, optimizer, loss_function, train_dataloader, val_dataloader, epoch
|
|
| 170 |
return train_loss_list, validation_loss_list
|
| 171 |
|
| 172 |
|
| 173 |
-
def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1, backpropagate=False, showResult= True):
|
| 174 |
if backpropagate:
|
| 175 |
model.train()
|
| 176 |
else:
|
| 177 |
-
model.eval()
|
| 178 |
|
| 179 |
source_sequence = source_sequence.float().to(device)
|
| 180 |
y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
|
| 181 |
-
|
| 182 |
i = 0
|
| 183 |
while i < max_length:
|
| 184 |
# Get source mask
|
|
|
|
| 185 |
prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0), # un-squeeze for batch
|
| 186 |
# tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
|
| 187 |
src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
|
| 188 |
-
|
| 189 |
next_embedding = prediction[0, -1, :] # prediction on last token
|
|
|
|
| 190 |
pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
|
| 191 |
#print(pred_deep_svg, pred_type, pred_parameters)
|
| 192 |
pred_deep_svg, pred_type, pred_parameters = pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(
|
| 193 |
device)
|
| 194 |
|
|
|
|
|
|
|
| 195 |
# === TYPE ===
|
| 196 |
# Apply Softmax
|
| 197 |
type_softmax = torch.softmax(pred_type, dim=0)
|
| 198 |
type_softmax[0] = type_softmax[0] * eos_scaling # Reduce EOS
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
# Break if EOS is most likely
|
| 202 |
if animation_type == 0:
|
|
@@ -222,6 +229,7 @@ def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=
|
|
| 222 |
|
| 223 |
# === SEQUENCE ===
|
| 224 |
y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
|
|
|
|
| 225 |
y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
|
| 226 |
|
| 227 |
# === INFO PRINT ===
|
|
|
|
| 1 |
import math
|
| 2 |
import time
|
| 3 |
+
import random
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
| 171 |
return train_loss_list, validation_loss_list
|
| 172 |
|
| 173 |
|
| 174 |
+
def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1, backpropagate=False, showResult= True, temperature=1):
|
| 175 |
if backpropagate:
|
| 176 |
model.train()
|
| 177 |
else:
|
| 178 |
+
model.eval()
|
| 179 |
|
| 180 |
source_sequence = source_sequence.float().to(device)
|
| 181 |
y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)
|
| 182 |
+
#print(source_sequence, source_sequence.unsqueeze(0))
|
| 183 |
i = 0
|
| 184 |
while i < max_length:
|
| 185 |
# Get source mask
|
| 186 |
+
#print(y_input, y_input.unsqueeze(0))
|
| 187 |
prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0), # un-squeeze for batch
|
| 188 |
# tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
|
| 189 |
src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))
|
|
|
|
| 190 |
next_embedding = prediction[0, -1, :] # prediction on last token
|
| 191 |
+
|
| 192 |
pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
|
| 193 |
#print(pred_deep_svg, pred_type, pred_parameters)
|
| 194 |
pred_deep_svg, pred_type, pred_parameters = pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(
|
| 195 |
device)
|
| 196 |
|
| 197 |
+
pred_type = pred_type / temperature
|
| 198 |
+
|
| 199 |
# === TYPE ===
|
| 200 |
# Apply Softmax
|
| 201 |
type_softmax = torch.softmax(pred_type, dim=0)
|
| 202 |
type_softmax[0] = type_softmax[0] * eos_scaling # Reduce EOS
|
| 203 |
+
|
| 204 |
+
indices = torch.argsort(type_softmax, descending=True)
|
| 205 |
+
animation_type = random.choice(indices[:3])
|
| 206 |
+
#animation_type = torch.argmax(type_softmax, dim=0)
|
| 207 |
|
| 208 |
# Break if EOS is most likely
|
| 209 |
if animation_type == 0:
|
|
|
|
| 229 |
|
| 230 |
# === SEQUENCE ===
|
| 231 |
y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
|
| 232 |
+
#y_new = torch.concat([pred_deep_svg, pred_type.to(device), pred_parameters], dim=0)
|
| 233 |
y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)
|
| 234 |
|
| 235 |
# === INFO PRINT ===
|
animationPipeline.py
CHANGED
|
@@ -15,10 +15,10 @@ def animateLogo(path : str, targetPath : str):
|
|
| 15 |
except Exception as e:
|
| 16 |
print(f"An error occurred: {e}")
|
| 17 |
#transformer
|
| 18 |
-
NUM_HEADS =
|
| 19 |
-
NUM_ENCODER_LAYERS =
|
| 20 |
-
NUM_DECODER_LAYERS =
|
| 21 |
-
DROPOUT=0.
|
| 22 |
# CONSTANTS
|
| 23 |
FEATURE_DIM = 282
|
| 24 |
|
|
@@ -34,7 +34,7 @@ def animateLogo(path : str, targetPath : str):
|
|
| 34 |
use_positional_encoder=True
|
| 35 |
).to(device)
|
| 36 |
|
| 37 |
-
model.load_state_dict(torch.load("models/
|
| 38 |
|
| 39 |
df = compute_embedding(path, "models/deepSVG_hierarchical_ordered.pth.tar")
|
| 40 |
df = df.drop("animation_id", axis=1)
|
|
@@ -46,7 +46,7 @@ def animateLogo(path : str, targetPath : str):
|
|
| 46 |
|
| 47 |
sos_token = torch.zeros(282)
|
| 48 |
sos_token[256] = 1
|
| 49 |
-
result = predict(model, inp, sos_token=sos_token, device=device, max_length=inp.shape[0], eos_scaling=
|
| 50 |
result = pd.DataFrame(result[1:, -26:].cpu().detach().numpy())
|
| 51 |
result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
|
| 52 |
result["animation_id"] = range(len(result))
|
|
|
|
| 15 |
except Exception as e:
|
| 16 |
print(f"An error occurred: {e}")
|
| 17 |
#transformer
|
| 18 |
+
NUM_HEADS = 47 # Dividers of 282: {1, 2, 3, 6, 47, 94, 141, 282}
|
| 19 |
+
NUM_ENCODER_LAYERS = 6
|
| 20 |
+
NUM_DECODER_LAYERS = 4
|
| 21 |
+
DROPOUT=0.21
|
| 22 |
# CONSTANTS
|
| 23 |
FEATURE_DIM = 282
|
| 24 |
|
|
|
|
| 34 |
use_positional_encoder=True
|
| 35 |
).to(device)
|
| 36 |
|
| 37 |
+
model.load_state_dict(torch.load("models/animation_transformer2.pth", map_location=torch.device('cpu')), strict=False)
|
| 38 |
|
| 39 |
df = compute_embedding(path, "models/deepSVG_hierarchical_ordered.pth.tar")
|
| 40 |
df = df.drop("animation_id", axis=1)
|
|
|
|
| 46 |
|
| 47 |
sos_token = torch.zeros(282)
|
| 48 |
sos_token[256] = 1
|
| 49 |
+
result = predict(model, inp, sos_token=sos_token, device=device, max_length=inp.shape[0], eos_scaling=0.5, temperature=100)
|
| 50 |
result = pd.DataFrame(result[1:, -26:].cpu().detach().numpy())
|
| 51 |
result = pd.DataFrame({"model_output" : [row.tolist() for index, row in result.iterrows()]})
|
| 52 |
result["animation_id"] = range(len(result))
|
models/{animation_transformer.pth → animation_transformer2.pth}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e63638f545c6f925a1a6d31578d507834de8ed30b71db2a0762c86859c597c44
|
| 3 |
+
size 69927679
|