flappybird1084 committed on
Commit
6caa466
·
1 Parent(s): 027ca38

add initial files

Browse files
README.md CHANGED
@@ -1,13 +1,15 @@
1
- ---
2
- title: Terrain Generation
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.39.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: '3D terrain generator from 2D segmentation mask. '
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
+ Terrain Reconstruction
2
+
3
+ ```bash
4
+ pyenv shell 3.11
5
+ python3 -m venv env
6
+ source env/bin/activate
7
+
8
+ pip install -r requirements.txt
9
+ mkdir -p models/terrain
10
+
11
+ python3 train_heightmap.py
12
+ python3 train_terrain.py
13
+ ```
14
+
15
+ A CUDA or MPS (Apple Silicon) device is advised for training.
app.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mpl_toolkits.mplot3d import Axes3D
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ import numpy as np
10
+ import os
11
+ import matplotlib
12
+ import base64
13
+ import tempfile
14
+ import trimesh
15
+ from io import BytesIO
16
+ import io
17
+ # Set the matplotlib backend to 'Agg' for non-interactive plotting in a server environment.
18
+ matplotlib.use('Agg')
19
+
20
+ # Define the DoubleConv and UNet classes exactly as in your notebook
21
+
22
+
23
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # padding=1 keeps spatial dimensions unchanged for a 3x3 kernel.
        # Built as a flat Sequential so state-dict keys stay conv.0 / conv.2.
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv+ReLU stages to *x*."""
        return self.conv(x)
37
+
38
+
39
class UNet(nn.Module):
    """U-Net with skip connections; stage widths set via ``features``."""

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling path: one DoubleConv per feature width.
        channels = in_channels
        for width in features:
            self.encoder.append(DoubleConv(channels, width))
            channels = width

        # Bottleneck doubles the deepest feature width.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Upsampling path: a (transpose-conv, DoubleConv) pair per stage.
        for width in reversed(features):
            self.decoder.append(nn.ConvTranspose2d(
                width * 2, width, kernel_size=2, stride=2))
            self.decoder.append(DoubleConv(width * 2, width))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encode, remembering each stage's output for the skip connections.
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skips.reverse()

        # Decode: decoder holds alternating (upsample, DoubleConv) modules.
        for stage_idx, skip in enumerate(skips):
            x = self.decoder[2 * stage_idx](x)
            # Odd input sizes can leave x smaller than the skip tensor.
            if x.shape != skip.shape:
                x = F.interpolate(
                    x, size=skip.shape[2:], mode='bilinear', align_corners=True)
            x = self.decoder[2 * stage_idx + 1](torch.cat((skip, x), dim=1))

        return self.final_conv(x)
88
+
89
+
90
+ # Helper function to convert PIL image to base64 data URI
91
+
92
+
93
def generate_mesh_from_images(heightmap_img, texture_img, max_height=100.0):
    """
    Convert a heightmap image and a texture image into 3D mesh data.

    Args:
        heightmap_img (PIL.Image): Grayscale image used as the heightmap.
        texture_img (PIL.Image): Color texture to map via UV coordinates.
        max_height (float): Maximum elevation represented in the mesh.

    Returns:
        dict: {
            'vertices': list of (x, y, z) tuples,
            'uvs': list of (u, v) tuples,
            'faces': list of (v0, v1, v2) vertex-index tuples,
            'dimensions': (width, height)
        }

    Raises:
        ValueError: If the two images differ in size.
    """
    # Ensure both images are the same size
    if heightmap_img.size != texture_img.size:
        raise ValueError("Heightmap and texture must be the same dimensions.")

    width, height = heightmap_img.size

    # Normalize the 8-bit grayscale heightmap to [0, max_height].
    height_data = np.asarray(heightmap_img.convert('L'),
                             dtype=np.float32) / 255.0
    height_data *= max_height

    vertices = []
    uvs = []
    faces = []

    # Guard against division by zero for degenerate 1-pixel dimensions
    # (the original code raised ZeroDivisionError in that case).
    u_den = max(width - 1, 1)
    v_den = max(height - 1, 1)

    # One vertex (and UV) per pixel; image y maps to world z-depth.
    for y in range(height):
        for x in range(width):
            z = height_data[y][x]
            vertices.append((x, z, y))  # World position
            uvs.append((x / u_den, y / v_den))  # UV coords

    # Two triangles per pixel quad, wound consistently.
    for y in range(height - 1):
        for x in range(width - 1):
            i = y * width + x
            i_right = i + 1
            i_bottom = i + width
            i_diag = i_bottom + 1

            # First triangle
            faces.append((i, i_bottom, i_right))
            # Second triangle
            faces.append((i_right, i_bottom, i_diag))

    return {
        'vertices': vertices,
        'uvs': uvs,
        'faces': faces,
        'dimensions': (width, height)
    }
150
+
151
+
152
def mesh_to_obj_string(mesh_data):
    """Serialize mesh data (vertices, uvs, faces) to Wavefront OBJ text."""
    out = []

    # Vertex positions.
    for x, y, z in mesh_data['vertices']:
        out.append(f"v {x:.6f} {y:.6f} {z:.6f}")

    # Texture coordinates; OBJ's V axis points up, so flip V.
    for u, v in mesh_data['uvs']:
        out.append(f"vt {u:.6f} {1.0 - v:.6f}")

    # Faces use 1-based "vertex/uv" index pairs (f v1/vt1 v2/vt2 v3/vt3).
    for a, b, c in mesh_data['faces']:
        out.append(f"f {a+1}/{a+1} {b+1}/{b+1} {c+1}/{c+1}")

    return '\n'.join(out)
177
+
178
+ # def mesh_to_obj_file(mesh_data):
179
+ # obj_str = mesh_to_obj_string(mesh_data)
180
+ # tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".obj", mode="w")
181
+ # tmp_file.write(obj_str)
182
+ # tmp_file.close()
183
+ # print(tmp_file.name)
184
+ # # return tmp_file.name # Return file path as string
185
+
186
+
187
def mesh_to_obj_file(mesh_data, texture_img):
    """Write mesh_data as an OBJ file with a sibling MTL and texture PNG.

    Returns only the OBJ path; gr.Model3D resolves the .mtl and texture
    because they live in the same folder.
    """
    obj_text = mesh_to_obj_string(mesh_data)

    # A fresh temp folder so relative paths inside the OBJ/MTL resolve.
    out_dir = tempfile.mkdtemp()
    obj_path = os.path.join(out_dir, "model.obj")
    mtl_path = os.path.join(out_dir, "model.mtl")
    texture_path = os.path.join(out_dir, "texture.png")

    # Save texture image
    texture_img.save(texture_path)

    # Write the material definition referencing the texture.
    mtl_contents = (
        "newmtl material0\n"
        "Ka 1.000 1.000 1.000\n"
        "Kd 1.000 1.000 1.000\n"
        "Ks 0.000 0.000 0.000\n"
        "d 1.0\n"
        "illum 2\n"
        "map_Kd texture.png\n"
    )
    with open(mtl_path, 'w') as mtl_file:
        mtl_file.write(mtl_contents)

    # Write the OBJ, pointing it at the material file.
    with open(obj_path, 'w') as obj_file:
        obj_file.write("mtllib model.mtl\n")
        obj_file.write("usemtl material0\n")
        obj_file.write(obj_text)

    return obj_path
219
+
220
+ # def render_3d_model(heightmap_img, texture_img):
221
+ # mesh = generate_mesh_from_images(heightmap_img, texture_img)
222
+ # obj_file = mesh_to_obj_file(mesh)
223
+ # return obj_file
224
+
225
+
226
def render_3d_model(heightmap_img, texture_img):
    """Build an OBJ (with material and texture files) from the two images."""
    mesh_data = generate_mesh_from_images(heightmap_img, texture_img)
    # Path to a .obj whose .mtl and texture sit alongside it.
    return mesh_to_obj_file(mesh_data, texture_img)
230
+
231
+ # def render_3d_model_glb(heightmap_img, texture_img, max_height=100.0):
232
+ # mesh_data = generate_mesh_from_images(heightmap_img, texture_img, max_height)
233
+
234
+ # vertices = np.array(mesh_data['vertices'], dtype=np.float32)
235
+ # faces = np.array(mesh_data['faces'], dtype=np.int64)
236
+ # uvs = np.array(mesh_data['uvs'], dtype=np.float32)
237
+
238
+ # # Convert heightmap + uvs into a mesh
239
+ # mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
240
+ # mesh.visual = trimesh.visual.TextureVisuals(uv=uvs)
241
+
242
+ # # Save the texture to a temporary file
243
+ # temp_folder = tempfile.mkdtemp()
244
+ # texture_path = os.path.join(temp_folder, "diffuse.png")
245
+ # texture_img.save(texture_path)
246
+
247
+ # material = trimesh.visual.material.PBRMaterial(
248
+ # baseColorTexture=trimesh.visual.texture.TextureVisuals(image=texture_path)
249
+ # )
250
+
251
+ # # Apply material (optional: set mesh.visual with material directly)
252
+ # mesh.visual.material = material
253
+
254
+ # # Assemble into a scene
255
+ # scene = trimesh.Scene()
256
+ # scene.add_geometry(mesh)
257
+
258
+ # # Export to glb
259
+ # glb_path = os.path.join(temp_folder, "terrain.glb")
260
+ # scene.export(glb_path, file_type='glb')
261
+ # return glb_path
262
+
263
+
264
def render_3d_model_glb(heightmap_img, texture_img, max_height=70.0):
    """Build a textured GLB terrain model and return its file path.

    Args:
        heightmap_img (PIL.Image): Grayscale heightmap.
        texture_img (PIL.Image): Color texture; flipped vertically to
            match the mesh's UV convention.
        max_height (float): Maximum elevation passed to mesh generation.

    Returns:
        str: Path to the exported .glb file in a fresh temp directory.
    """
    mesh_data = generate_mesh_from_images(
        heightmap_img, texture_img, max_height)

    # GLB texture origin differs from PIL's, so flip vertically.
    texture_img = texture_img.transpose(Image.FLIP_TOP_BOTTOM)

    vertices = np.array(mesh_data['vertices'], dtype=np.float32)
    faces = np.array(mesh_data['faces'], dtype=np.int64)
    uvs = np.array(mesh_data['uvs'], dtype=np.float32)

    # process=False keeps vertex order aligned with the UV array.
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
    mesh.visual = trimesh.visual.TextureVisuals(uv=uvs)

    # Attach the texture directly as a PIL image. The original code also
    # wrote a temporary PNG to disk that was never read or deleted (a
    # leaked temp file); that step is removed.
    mesh.visual.material.image = texture_img

    # Build scene and export as binary glTF.
    scene = trimesh.Scene()
    scene.add_geometry(mesh)

    glb_path = os.path.join(tempfile.mkdtemp(), "terrain.glb")
    scene.export(glb_path, file_type='glb')

    return glb_path
298
+
299
+
300
# --- Model and Presets Loading ---
script_dir = os.path.dirname(os.path.abspath(__file__))
heightmap_model_path = os.path.join(
    script_dir, './models/terrain/turbo_heightmap_unet_model.pth')
terrain_model_path = os.path.join(
    script_dir, './models/terrain/turbo_terrain_unet_model.pth')
presets_folder_path = os.path.join(script_dir, './presets')


# Prefer Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Initialize models with the same architectures used at training time.
heightmap_gen_model = UNet(in_channels=3, out_channels=1, features=[
                           64, 128, 256, 512, 1024]).to(device)
terrain_gen_model = UNet(in_channels=3, out_channels=3).to(device)

try:
    print(f"Attempting to load heightmap model from: {heightmap_model_path}")
    heightmap_gen_model.load_state_dict(torch.load(
        heightmap_model_path, map_location=device))
    print(f"Attempting to load terrain model from: {terrain_model_path}")
    terrain_gen_model.load_state_dict(torch.load(
        terrain_model_path, map_location=device))
    print("--- Models loaded successfully. ---")
except Exception as e:
    print(f"FATAL: Could not load models. Error: {e}")
    # raise SystemExit instead of the site-injected exit() helper, which
    # is not guaranteed to exist in every runtime.
    raise SystemExit(1)

# Load preset image paths
example_paths = []
if os.path.exists(presets_folder_path):
    for filename in os.listdir(presets_folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            example_paths.append(os.path.join(presets_folder_path, filename))
    print(f"Found {len(example_paths)} preset images in {presets_folder_path}")
else:
    print("no presets found!! oh noes")


# Define the image transformation pipeline
transform_pipeline = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
357
+
358
+
359
def generate_3d_plot(heightmap_np, terrain_np, elev, azim):
    """
    Render a 3D surface plot of the heightmap, colored by the terrain map.

    Args:
        heightmap_np: 2D (or squeezable) array of elevations.
        terrain_np: HxWx3 uint8 color image used as face colors.
        elev, azim: Camera angles for the 3D view.

    Returns:
        matplotlib.figure.Figure: The rendered figure.
    """
    elevation_grid = heightmap_np.squeeze()

    rows, cols = elevation_grid.shape
    X, Y = np.meshgrid(np.arange(cols), np.arange(rows))
    Z = elevation_grid.astype(np.float32)

    # plot_surface expects facecolors in [0, 1].
    face_colors = terrain_np / 255.0

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    # [X, Y, Z] ratio; make Z axis 30% the scale of X/Y
    ax.set_box_aspect([1, 1, 0.3])

    # Stride of 2 trades surface resolution for render speed.
    ax.plot_surface(X, Y, Z, facecolors=face_colors,
                    rstride=2, cstride=2, linewidth=0, antialiased=False)

    # Set view and labels using slider values
    ax.view_init(elev=elev, azim=azim)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z (Elevation)')
    ax.set_title("3D Rendered Terrain")

    plt.tight_layout()
    return fig
393
+
394
+
395
def gaussian_blur(tensor, kernel_size=5, sigma=1.0):
    """Depthwise Gaussian blur of a [B, C, H, W] tensor.

    Spatial dimensions are preserved via symmetric zero padding.
    """
    # Sample a normalized 1D Gaussian on integer offsets centered at zero.
    offsets = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    kernel_1d = torch.exp(-offsets ** 2 / (2 * sigma ** 2))
    kernel_1d = (kernel_1d / kernel_1d.sum()).to(tensor.device)

    # Separable filter: outer product gives the 2D Gaussian.
    kernel_2d = torch.outer(kernel_1d, kernel_1d)

    # One identical kernel per channel (depthwise / grouped convolution);
    # conv2d weights are [out_channels, in_channels/groups, H, W].
    channels = tensor.shape[1]
    weight = kernel_2d.expand(channels, 1, kernel_size, kernel_size)

    # Padding of half the kernel keeps H and W unchanged.
    pad = kernel_size // 2
    return F.conv2d(tensor, weight, padding=pad, groups=channels)
414
+
415
+
416
def predict(input_image_pil, elevation, azimuth):
    """
    Generate heightmap + terrain images, a 3D plot, and a GLB model from
    a segmentation map.

    Args:
        input_image_pil (PIL.Image | None): Input segmentation map.
        elevation, azimuth: Camera angles for the 3D plot.

    Returns:
        tuple: (heightmap PIL.Image, terrain PIL.Image, matplotlib
        Figure, path to .glb file or None). Blank placeholders are
        returned when no image is provided.
    """
    if input_image_pil is None:
        # Return blank outputs if no image is provided. The original
        # returned only 3 values here while the normal path returns 4,
        # which broke the Gradio output mapping; return 4 consistently.
        blank_image = Image.new('RGB', (256, 256), 'white')
        blank_plot = plt.figure()
        plt.plot([])
        return blank_image, blank_image, blank_plot, None

    # Ensure it's in RGB format
    input_image_pil = input_image_pil.convert("RGB")

    input_tensor = transform_pipeline(input_image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        heightmap_gen_model.eval()
        terrain_gen_model.eval()
        generated_heightmap_tensor = heightmap_gen_model(input_tensor)
        # Smooth the raw network outputs to suppress high-frequency noise.
        generated_heightmap_tensor = gaussian_blur(
            generated_heightmap_tensor, kernel_size=5, sigma=1.2)

        generated_terrain_tensor = terrain_gen_model(input_tensor)
        generated_terrain_tensor = gaussian_blur(
            generated_terrain_tensor, kernel_size=5, sigma=1.1)

    # Post-process for 2D image outputs
    heightmap_np = generated_heightmap_tensor.squeeze(
        0).cpu().permute(1, 2, 0).numpy()
    terrain_np = generated_terrain_tensor.squeeze(
        0).cpu().permute(1, 2, 0).numpy()

    # Min-max normalize for visualization. The epsilon guards against a
    # division by zero when the network output is perfectly flat (the
    # original code crashed in that case).
    eps = 1e-8
    heightmap_np_viz = (heightmap_np - heightmap_np.min()) / \
        max(heightmap_np.max() - heightmap_np.min(), eps)
    terrain_np_viz = (terrain_np - terrain_np.min()) / \
        max(terrain_np.max() - terrain_np.min(), eps)

    heightmap_image = Image.fromarray(
        (heightmap_np_viz * 255).astype(np.uint8).squeeze(), 'L')
    terrain_image = Image.fromarray((terrain_np_viz * 255).astype(np.uint8))

    # Generate the 3D plot using the numpy arrays and slider values
    plot_3d = generate_3d_plot(
        heightmap_np_viz, (terrain_np_viz * 255).astype(np.uint8), elevation, azimuth)

    # Close the figure to free memory (Gradio keeps the figure object).
    plt.close(plot_3d)

    object_3d = render_3d_model_glb(heightmap_image, terrain_image)

    return heightmap_image, terrain_image, plot_3d, object_3d
475
+
476
+
477
+ # Create the Gradio Interface
478
# Create the Gradio Interface
with gr.Blocks() as iface:
    gr.Markdown("# 2D and 3D Terrain Generator")
    gr.Markdown("Upload, draw, or choose a preset segmentation map to generate a 2D heightmap, a 2D terrain image, and a 3D rendered terrain.")

    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Upload & Presets"):
                    input_img_upload = gr.Image(
                        type="pil", label="Input Segmentation Map")
                    if example_paths:
                        gr.Examples(
                            examples=example_paths,
                            inputs=input_img_upload,
                            label="Preset Segmentation Maps"
                        )
                with gr.Tab("Draw"):
                    terrain_colors = [
                        "#118DD7",  # Water 💧
                        "#E1E39B",  # Grassland 🌾
                        "#7FAD7B",  # Forest 🌲
                        "#B97A57",  # Hills ⛰️
                        "#E6C8B5",  # Desert 🏜️
                        "#969696",  # Mountain 🏔️
                        "#C1BEAF"   # Tundra ❄️
                    ]
                    sketchpad = gr.ImageEditor(
                        type="pil", label="Draw Segmentation Map", height=512, width=512, brush=gr.Brush(colors=terrain_colors))

            elevation_slider = gr.Slider(
                minimum=0, maximum=90, value=30, step=1, label="Elevation Angle")
            azimuth_slider = gr.Slider(
                minimum=0, maximum=360, value=45, step=1, label="Azimuth Angle")
            btn = gr.Button("Generate")

        with gr.Column():
            output_heightmap = gr.Image(
                type="pil", label="Generated Heightmap (2D)")
            output_terrain = gr.Image(
                type="pil", label="Generated Terrain (2D)")
            output_plot = gr.Plot(label="Generated Terrain (3D)")
            output_3d_viewer = gr.Model3D(
                label="Generated 3D Object (not particularly accurate)")

    # Wrapper function to decide which input to use: prefer a drawing,
    # fall back to the uploaded image.
    def wrapper_predict(uploaded_img, drawn_img_dict, elevation, azimuth):
        image_to_use = None
        # Check if the user has drawn something meaningful
        if drawn_img_dict and drawn_img_dict["composite"] is not None:
            image_to_use = drawn_img_dict["composite"]
        # Otherwise, fall back to the uploaded image
        elif uploaded_img is not None:
            image_to_use = uploaded_img

        return predict(image_to_use, elevation, azimuth)

    # predict() returns four values, so every event's outputs list must
    # name four components. The slider handlers originally listed only
    # three, making Gradio reject the extra return value.
    all_outputs = [output_heightmap, output_terrain,
                   output_plot, output_3d_viewer]
    all_inputs = [input_img_upload, sketchpad,
                  elevation_slider, azimuth_slider]

    # The 'Generate' button triggers the prediction
    btn.click(fn=wrapper_predict, inputs=all_inputs, outputs=all_outputs)

    # Allow sliders to update the outputs interactively when released
    elevation_slider.release(
        fn=wrapper_predict, inputs=all_inputs, outputs=all_outputs)
    azimuth_slider.release(
        fn=wrapper_predict, inputs=all_inputs, outputs=all_outputs)

# Launch the app
if __name__ == "__main__":
    iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ matplotlib
4
+ trimesh
5
+ pygltflib
6
+ numpy
7
+ seaborn
8
+ gradio
9
+ pillow
train_heightmap.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.optim as optim
import torch.nn as nn
from util.unet import UNet
import torchvision.transforms as transforms
import util.dataset as ds
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torchvision.models as models

# change for your own dataset path.
# dataset: https://www.kaggle.com/datasets/tpapp157/earth-terrain-height-and-segmentation-map-images
dataset_path = "../../Other/cosmos/data/terrain_reconstruction/_dataset/"


# Resize + tensor conversion only; normalization was tried and is
# deliberately left disabled.
transform_pipeline = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = ds.TerrainDataset(dataset_path, transform=transform_pipeline)

# 80/20 train/test split.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
dataset_train, dataset_test = random_split(dataset, [train_size, test_size])

# Prefer Apple MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available(
) else "cuda" if torch.cuda.is_available() else "cpu")

# Data loaders (single-process loading).
numworkers = 0
batchsize = 8
train_loader = DataLoader(
    dataset_train, batch_size=batchsize, shuffle=True, num_workers=numworkers)
test_loader = DataLoader(dataset_test, batch_size=batchsize,
                         shuffle=False, num_workers=numworkers)
45
class PerceptualLoss(nn.Module):
    """VGG16 feature-space MSE loss (frozen features, ImageNet-normalized)."""

    def __init__(self, feature_layer=9):
        super(PerceptualLoss, self).__init__()
        # Truncate pretrained VGG16 at `feature_layer` and freeze it.
        vgg = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT).features[:feature_layer].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        # ImageNet normalization expected by the pretrained weights.
        self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, pred, target):
        """MSE between VGG feature maps of `pred` and `target`.

        Both inputs are assumed to be 3-channel images in [0, 1]
        # NOTE(review): range assumption inferred from the normalization —
        # confirm against the callers.
        """
        return nn.functional.mse_loss(
            self.vgg(self.transform(pred)),
            self.vgg(self.transform(target)))
60
+
61
+
62
def total_variation_loss(x):
    """Mean absolute difference between neighboring pixels (width + height)."""
    width_diff = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]).mean()
    height_diff = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]).mean()
    return width_diff + height_diff
65
+
66
+
67
# Heightmap UNet: single-channel raw output (no sigmoid).
unet_model = UNet(in_channels=3, out_channels=1, use_sigmoid=False, features=[
    64, 128, 256, 512, 1024]).to(device)

mse_loss = nn.MSELoss()
perceptual_loss = PerceptualLoss().to(device)
perceptual_loss_scaling_factor = 0.1
optimizer = optim.Adam(unet_model.parameters(), lr=0.001)


# unet_model.load_state_dict(torch.load('./models/terrain/heightmap_unet_model.pth'))
num_epochs = 5
for epoch in range(num_epochs):
    unet_model.train()
    running_loss = 0.0

    for i, (height, terrain, segmentation) in enumerate(train_loader):
        images = segmentation.to(device).float()
        # Heightmaps are 16-bit, hence the /65535 scaling below.
        # NOTE(review): confirm the dataset really yields 16-bit values.
        target_images = height.to(device).float()

        # Forward pass
        outputs = unet_model(images)

        # VGG expects 3 channels: replicate the single height channel.
        outputs_rgb = outputs.repeat(1, 3, 1, 1)
        targets_rgb = target_images.repeat(1, 3, 1, 1)

        # MSE + scaled perceptual loss + a tiny total-variation smoother.
        tv_weight = 1e-6
        loss = (mse_loss(outputs/65535, target_images/65535)
                + perceptual_loss_scaling_factor
                * perceptual_loss(outputs_rgb/65535, targets_rgb/65535)
                + tv_weight * total_variation_loss(outputs/65535))
        running_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            # The original printed `epoch + 1/num_epochs` — an operator
            # precedence bug (epoch plus 0.2). Report proper counters.
            print(f"Epoch {epoch + 1}/{num_epochs}, "
                  f"Step {i + 1}/{len(train_loader)}, "
                  f"Loss: {loss.item():.4f}")

torch.save(unet_model.state_dict(),
           './models/terrain/turbo_heightmap_unet_model.pth')
print("Model saved to './models/terrain/turbo_heightmap_unet_model.pth'")
train_terrain.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.optim as optim
import torch.nn as nn
from util.unet import UNet
import torchvision.transforms as transforms
import util.dataset as ds
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torchvision.models as models


dataset_path = "../../Other/cosmos/data/terrain_reconstruction/_dataset/"


# Resize + tensor conversion only; normalization intentionally disabled.
transform_pipeline = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = ds.TerrainDataset(dataset_path, transform=transform_pipeline)

# 80/20 train/test split.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
dataset_train, dataset_test = random_split(dataset, [train_size, test_size])

# Prefer Apple MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available(
) else "cuda" if torch.cuda.is_available() else "cpu")

# Data loaders (single-process loading).
numworkers = 0
batchsize = 8
train_loader = DataLoader(
    dataset_train, batch_size=batchsize, shuffle=True, num_workers=numworkers)
test_loader = DataLoader(dataset_test, batch_size=batchsize,
                         shuffle=False, num_workers=numworkers)
44
class PerceptualLoss(nn.Module):
    """VGG16 feature-space MSE loss (frozen features, ImageNet-normalized)."""

    def __init__(self, feature_layer=9):
        super(PerceptualLoss, self).__init__()
        # Truncate pretrained VGG16 at `feature_layer` and freeze it.
        vgg = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT).features[:feature_layer].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        # ImageNet normalization expected by the pretrained weights.
        self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, pred, target):
        """MSE between VGG feature maps of `pred` and `target` (3-channel)."""
        return nn.functional.mse_loss(
            self.vgg(self.transform(pred)),
            self.vgg(self.transform(target)))
59
+
60
+
61
def total_variation_loss(x):
    """Mean absolute difference between neighboring pixels (width + height)."""
    width_diff = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]).mean()
    height_diff = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]).mean()
    return width_diff + height_diff
64
+
65
+
66
# Initialize UNet model: 3-channel RGB terrain output.
unet_model = UNet(in_channels=3, out_channels=3).to(device)
mse_loss = nn.MSELoss()
perceptual_loss = PerceptualLoss().to(device)
# Weight of the perceptual term relative to plain MSE.
perceptual_loss_scaling_factor = 0.1
optimizer = optim.Adam(unet_model.parameters(), lr=0.001)


# Optionally resume from the previously saved checkpoint.
train_previous = False
if train_previous:
    unet_model.load_state_dict(torch.load(
        './models/terrain/turbo_terrain_unet_model.pth'))
    print("Loaded previous model state from './models/terrain/turbo_terrain_unet_model.pth'")

num_epochs = 5
for epoch in range(num_epochs):
    # NOTE(review): unet_model.train() was commented out in the original;
    # the model stays in its default mode throughout training.
    running_loss = 0.0

    for i, (height, terrain, segmentation) in enumerate(train_loader):
        # Map targets to [-1, 1].
        # NOTE(review): assumes terrain tensors arrive in [0, 1] — confirm.
        terrain = (terrain * 2) - 1

        images = segmentation.to(device)
        target_images = terrain.to(device)

        # Forward pass
        outputs = unet_model(images)
        # MSE plus scaled perceptual loss.
        loss = mse_loss(outputs, target_images) + perceptual_loss_scaling_factor * \
            perceptual_loss(outputs, target_images)
        running_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print(f"epoch: {epoch+1}")
            print(f"step: {i+1}/{len(train_loader)}")
            print(f"loss: {loss.item():.4f}")


torch.save(unet_model.state_dict(),
           './models/terrain/turbo_terrain_unet_model.pth')
print("Model saved to './models/terrain/turbo_terrain_unet_model.pth'")
util/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (3.4 kB). View file
 
util/__pycache__/unet.cpython-311.pyc ADDED
Binary file (4.69 kB). View file
 
util/dataset.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import os
3
+ from PIL import Image
4
+
5
+
6
class TerrainDataset(Dataset):
    """Triplet dataset of (heightmap, terrain, segmentation) images.

    Files in ``data_dir`` are matched by name markers: '_h' heightmap,
    '_t' terrain, '_i' (including '_i2') segmentation.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        names = os.listdir(data_dir)

        # Sorting keeps the three lists aligned via shared name prefixes.
        self.height_paths = sorted(
            os.path.join(data_dir, f) for f in names if '_h' in f)
        self.terrain_paths = sorted(
            os.path.join(data_dir, f) for f in names if '_t' in f)
        # ('_i2' in f) implies ('_i' in f), so one membership test suffices.
        self.segmentation_paths = sorted(
            os.path.join(data_dir, f) for f in names if '_i' in f)

        assert len(self.height_paths) == len(self.terrain_paths) == len(self.segmentation_paths), \
            "Mismatch in dataset triplet lengths"

        print(f"Found {len(self.height_paths)} triplets in {data_dir}")

    def __len__(self):
        return len(self.height_paths)

    def __getitem__(self, idx):
        # Load the triplet in (heightmap, terrain, segmentation) order.
        triplet_paths = (self.height_paths[idx], self.terrain_paths[idx],
                         self.segmentation_paths[idx])
        loaded = []
        for path in triplet_paths:
            image = Image.open(path)
            if self.transform:
                image = self.transform(image)
            loaded.append(image)
        return tuple(loaded)
util/unet.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class DoubleConv(nn.Module):
    """Two lazily-initialized 3x3 convs, each followed by in-place ReLU.

    LazyConv2d infers its input channel count on the first forward pass,
    so only the output width needs to be declared.
    """

    def __init__(self, out_channels):
        super(DoubleConv, self).__init__()
        # padding=1 preserves spatial size for a 3x3 kernel.
        stages = []
        for _ in range(2):
            stages.append(nn.LazyConv2d(out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv+ReLU stages on *x*."""
        return self.conv(x)
18
+
19
+
20
class UNet(nn.Module):
    """U-Net built from lazily-initialized DoubleConv blocks.

    Args:
        in_channels: Declared input channel count (the lazy convs infer
            it anyway; kept for interface compatibility).
        out_channels: Output channels of the final 1x1 conv.
        features: Channel widths of the encoder stages.
        use_sigmoid: If True, apply the output activation (Sigmoid for
            single-channel output, identity otherwise).
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], use_sigmoid=True):
        # Call nn.Module.__init__ before setting any attribute; the
        # original assigned use_sigmoid first, which only works by
        # relying on nn.Module.__setattr__ implementation details.
        super(UNet, self).__init__()
        self.use_sigmoid = use_sigmoid

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per feature width.
        for feature in features:
            self.encoder.append(DoubleConv(feature))

        # Bottleneck doubles the deepest width.
        self.bottleneck = DoubleConv(features[-1] * 2)

        # Decoder: an (upsample, DoubleConv) pair per stage.
        for feature in reversed(features):
            self.decoder.append(nn.ConvTranspose2d(
                feature * 2, feature, kernel_size=2, stride=2))
            self.decoder.append(DoubleConv(feature))  # after concatenation

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        self.output_activation = nn.Sigmoid() if out_channels == 1 else nn.Identity()

    def forward(self, x):
        skip_connections = []

        # Encode, saving each stage's output for the skip connections.
        for layer in self.encoder:
            x = layer(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)  # upsample
            skip_connection = skip_connections[idx // 2]
            # Odd input sizes can leave x smaller than the skip tensor.
            if x.shape != skip_connection.shape:
                x = F.interpolate(
                    x, size=skip_connection.shape[2:], mode='bilinear', align_corners=True)
            x = torch.cat((skip_connection, x), dim=1)  # concat
            x = self.decoder[idx + 1](x)  # double conv

        if self.use_sigmoid:
            return self.output_activation(self.final_conv(x))
        return self.final_conv(x)