Spaces:
Runtime error
Runtime error
context
Browse files
app.py
CHANGED
|
@@ -161,9 +161,38 @@ def sample_ddim(n_sample, n=20):
|
|
| 161 |
intermediate = np.stack(intermediate)
|
| 162 |
return samples, intermediate
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def greet(input):
|
| 165 |
steps = int(input)
|
| 166 |
-
samples, intermediate = sample_ddim(32, n=steps)
|
|
|
|
|
|
|
| 167 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
| 168 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
| 169 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|
|
|
|
| 161 |
intermediate = np.stack(intermediate)
|
| 162 |
return samples, intermediate
|
| 163 |
|
| 164 |
+
# load in model weights and set to eval mode
|
| 165 |
+
nn_model.load_state_dict(torch.load(f"{save_dir}/context_model_31.pth", map_location=device))
|
| 166 |
+
nn_model.eval()
|
| 167 |
+
print("Loaded in Context Model")
|
| 168 |
+
|
| 169 |
+
# fast sampling algorithm with context
|
| 170 |
+
@torch.no_grad()
|
| 171 |
+
def sample_ddim_context(n_sample, context, n=20):
|
| 172 |
+
# x_T ~ N(0, 1), sample initial noise
|
| 173 |
+
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 174 |
+
|
| 175 |
+
# array to keep track of generated steps for plotting
|
| 176 |
+
intermediate = []
|
| 177 |
+
step_size = timesteps // n
|
| 178 |
+
for i in range(timesteps, 0, -step_size):
|
| 179 |
+
print(f'sampling timestep {i:3d}', end='\r')
|
| 180 |
+
|
| 181 |
+
# reshape time tensor
|
| 182 |
+
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
|
| 183 |
+
|
| 184 |
+
eps = nn_model(samples, t, c=context) # predict noise e_(x_t,t)
|
| 185 |
+
samples = denoise_ddim(samples, i, i - step_size, eps)
|
| 186 |
+
intermediate.append(samples.detach().cpu().numpy())
|
| 187 |
+
|
| 188 |
+
intermediate = np.stack(intermediate)
|
| 189 |
+
return samples, intermediate
|
| 190 |
+
|
| 191 |
def greet(input):
|
| 192 |
steps = int(input)
|
| 193 |
+
#samples, intermediate = sample_ddim(32, n=steps)
|
| 194 |
+
ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
| 195 |
+
samples, intermediate = sample_ddim_context(32, ctx)
|
| 196 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
| 197 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
| 198 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|