Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -50,6 +50,16 @@ model = ConvVAE()
|
|
| 50 |
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
|
| 51 |
model.eval()
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def generate_map(seed: int = None):
|
| 54 |
model.eval()
|
| 55 |
if seed is None:
|
|
|
|
| 50 |
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
|
| 51 |
model.eval()
|
| 52 |
|
| 53 |
+
# Sampling
|
| 54 |
+
def sample_with_temperature(probs, temperature=1.2):
|
| 55 |
+
logits = torch.log(probs + 1e-8) / temperature
|
| 56 |
+
scaled_probs = torch.softmax(logits, dim=1)
|
| 57 |
+
batch, channels, height, width = scaled_probs.shape
|
| 58 |
+
scaled_probs = scaled_probs.permute(0, 2, 3, 1).contiguous().view(-1, channels)
|
| 59 |
+
sampled = torch.multinomial(scaled_probs, num_samples=1)
|
| 60 |
+
sampled = sampled.view(batch, height, width)
|
| 61 |
+
return sampled
|
| 62 |
+
|
| 63 |
def generate_map(seed: int = None):
|
| 64 |
model.eval()
|
| 65 |
if seed is None:
|