import os

# These defaults are set before the third-party imports below; presumably the
# libraries read them at import time and need writable /tmp locations.
os.environ.setdefault("HF_HOME", "/tmp/hf_cache")
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio")

# NOTE(review): import order is kept exactly as in the original (``spaces``
# ahead of ``torch``); confirm before letting a formatter re-sort these.
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from starlette.responses import FileResponse, HTMLResponse, JSONResponse
import base64
import io
import math
import time
import uuid
from pathlib import Path
from terrain_diffusion.inference.relief_map import get_relief_map
from terrain_diffusion.inference.world_pipeline import WorldPipeline

MODEL_ID = "xandergos/terrain-diffusion-30m"
PAPER_URL = "https://arxiv.org/abs/2512.08309"
MODEL_URL = "https://huggingface.co/xandergos/terrain-diffusion-30m"

# Directory that generated heightmap PNGs are written to and served from.
OUTPUT_DIR = Path("/tmp/terrain_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

PIXEL_RESOLUTION_M = 30
PADDING = 4
GRID_SIZE = 4
# Side length of the conditioning grid: the interior plus padding on each side.
TOTAL = GRID_SIZE + 2 * PADDING
# Upper bound on rows/columns sent to the client in the mesh payload.
MESH_MAX_SAMPLES = 160

# Elevation thresholds (metres) and colour ramps shared by the scalar and the
# vectorised colouring helpers so the two cannot drift apart.
_BEACH_MAX_M = 70.0
_GRASS_MAX_M = 950.0
_MIN_SNOW_LINE_M = 1850.0
_WATER_RAMP = ((18, 53, 93), (47, 112, 151))
_BEACH_RAMP = ((182, 158, 105), (219, 197, 137))
_GRASS_RAMP = ((47, 116, 71), (126, 144, 79))
_ROCK_RAMP = ((107, 108, 99), (160, 158, 147))
_SNOW_RAMP = ((218, 224, 219), (250, 252, 247))

# Lazily created pipeline shared by every request in this process.
_world = None


def _get_world():
    """Return the shared ``WorldPipeline``, loading it on first use.

    The pipeline is created from ``MODEL_ID``, moved to the CPU and bound with
    ``hdf5_file=None``; later calls return the cached instance.
    """
    global _world
    if _world is not None:
        return _world
    print("Loading terrain diffusion model...")
    t0 = time.time()
    _world = WorldPipeline.from_pretrained(
        MODEL_ID,
        seed=42,
        latents_batch_size=[1, 2, 4, 8, 16],
        log_mode="info",
        torch_compile=False,
        dtype=None,
        caching_strategy="direct",
        cache_limit=100 * 1024 * 1024,
    )
    _world.to("cpu")
    _world.bind(hdf5_file=None)
    print(f"Model loaded in {time.time() - t0:.1f}s, seed={_world.seed}")
    return _world


@spaces.GPU(duration=1)
def _zerogpu_probe():
    """Return ``"ready"`` from inside a ``spaces.GPU``-decorated function.

    NOTE(review): nothing in this file calls it; presumably it exists so the
    ZeroGPU runtime sees a GPU function at startup — confirm.
    """
    return "ready"


def _empty_grid(default_val=-1000.0):
    """Return a ``TOTAL x TOTAL`` float32 grid filled with ``default_val``."""
    return np.full((TOTAL, TOTAL), default_val, dtype=np.float32)


def _centered_offsets(start=0, stop=None):
    """Return indices ``start..stop-1`` as offsets from ``TOTAL // 2``.

    Offsets are expressed in units of half the interior size
    (``GRID_SIZE / 2``), matching the normalisation used by every preset.
    """
    stop = TOTAL if stop is None else stop
    return (np.arange(start, stop) - TOTAL // 2) / (GRID_SIZE / 2)


def _radial_distance():
    """Return the ``TOTAL x TOTAL`` array of normalised distances from centre."""
    offsets = _centered_offsets()
    return np.sqrt(offsets[:, None] ** 2 + offsets[None, :] ** 2)


def _mountains_grid():
    """Build the "Mountains" preset: a cone peaking at 2500 with a foothill ring."""
    dist = _radial_distance()
    g = _empty_grid()
    # Assign the outer ring first so the inner cone overrides it, which is
    # equivalent to an if/elif on ``dist``.
    foothills = dist < 1.5
    g[foothills] = (500.0 * (1.5 - dist) / 0.5)[foothills]
    cone = dist < 1.0
    g[cone] = (2500.0 * (1.0 - dist))[cone]
    return g


def _plains_grid():
    """Build the "Plains" preset: gentle sinusoidal relief in the interior only."""
    interior = slice(PADDING, PADDING + GRID_SIZE)
    offsets = _centered_offsets(PADDING, PADDING + GRID_SIZE)
    g = _empty_grid()
    g[interior, interior] = (
        200.0 + 150.0 * np.sin(offsets[:, None] * 3) * np.cos(offsets[None, :] * 3)
    )
    return g


def _islands_grid():
    """Build the "Islands" preset: a small central island on a -2000 background."""
    dist = _radial_distance()
    g = _empty_grid(-2000.0)
    shore = dist < 0.8
    g[shore] = (100.0 * (0.8 - dist) / 0.3)[shore]
    summit = dist < 0.5
    g[summit] = (800.0 * (1.0 - dist / 0.5))[summit]
    return g


def _canyon_grid():
    """Build the "Canyon" preset: a low central column between plateau columns.

    Every row gets the same column profile; only interior columns are written.
    """
    cols = np.arange(PADDING, PADDING + GRID_SIZE)
    dist_from_center = np.abs(cols - TOTAL // 2) / (GRID_SIZE / 2)
    # NOTE(review): with GRID_SIZE = 4 the distances are multiples of 0.5, so
    # the 1500.0 branch (0.3 <= d < 0.5) is never taken. Behaviour is kept.
    profile = np.where(
        dist_from_center < 0.3,
        -300.0,
        np.where(dist_from_center < 0.5, 1500.0, 800.0),
    )
    g = _empty_grid()
    g[:, PADDING:PADDING + GRID_SIZE] = profile
    return g


# Preset name -> zero-argument builder returning a fresh conditioning grid.
PRESETS = {
    "Mountains": _mountains_grid,
    "Plains": _plains_grid,
    "Islands": _islands_grid,
    "Canyon": _canyon_grid,
}

PRESET_DESCRIPTIONS = {
    "Mountains": "High central relief with foothill context.",
    "Plains": "Low rolling elevation with broad interior variation.",
    "Islands": "Ocean-biased layout with compact emergent land.",
    "Canyon": "Incised central valley with raised shoulders.",
}


def _snow_line(vmin, vmax):
    """Return the snow threshold: 72% of the range, but never below 1850."""
    return max(_MIN_SNOW_LINE_M, vmin + (vmax - vmin) * 0.72)


def _lerp_color(a, b, t):
    """Linearly blend colours ``a`` and ``b`` with ``t`` clamped to [0, 1]."""
    t = float(np.clip(t, 0.0, 1.0))
    return np.array(a, dtype=np.float32) * (1 - t) + np.array(b, dtype=np.float32) * t


def _terrain_color(height, vmin, vmax):
    """Return the float32 RGB colour for one elevation value.

    Scalar counterpart of ``_colorize_elevation``: the same bands (water,
    beach, grass, rock, snow) and the same ramps are used.
    """
    snow_line = _snow_line(vmin, vmax)
    if height < 0:
        t = np.clip((height - vmin) / max(1.0, 0.0 - vmin), 0, 1)
        return _lerp_color(*_WATER_RAMP, t)
    if height < _BEACH_MAX_M:
        return _lerp_color(*_BEACH_RAMP, height / _BEACH_MAX_M)
    if height < _GRASS_MAX_M:
        return _lerp_color(
            *_GRASS_RAMP, (height - _BEACH_MAX_M) / (_GRASS_MAX_M - _BEACH_MAX_M)
        )
    if height < snow_line:
        return _lerp_color(
            *_ROCK_RAMP, (height - _GRASS_MAX_M) / max(1.0, snow_line - _GRASS_MAX_M)
        )
    return _lerp_color(
        *_SNOW_RAMP, (height - snow_line) / max(1.0, vmax - snow_line)
    )


def _colorize_elevation(elev):
    """Map an elevation array to a uint8 RGB image of shape ``elev.shape + (3,)``.

    Each pixel is placed in one band (water, beach, grass, rock, snow) and
    coloured by interpolating along that band's ramp. Pixels matching no band
    (for example NaN) stay black. An empty input yields an empty image
    instead of raising inside ``np.nanmin``.
    """
    arr = np.asarray(elev, dtype=np.float32)
    if arr.size == 0:
        return np.zeros(arr.shape + (3,), dtype=np.uint8)

    vmin = float(np.nanmin(arr))
    vmax = float(np.nanmax(arr))
    snow_line = _snow_line(vmin, vmax)
    rgb = np.zeros(arr.shape + (3,), dtype=np.float32)

    def apply(mask, ramp, t):
        """Paint ``mask`` pixels by blending along ``ramp`` with weights ``t``."""
        if not np.any(mask):
            return
        weight = np.clip(t, 0.0, 1.0)[..., None]
        low_arr = np.array(ramp[0], dtype=np.float32)
        high_arr = np.array(ramp[1], dtype=np.float32)
        rgb[mask] = low_arr * (1.0 - weight[mask]) + high_arr * weight[mask]

    water = arr < 0
    beach = (arr >= 0) & (arr < _BEACH_MAX_M)
    grass = (arr >= _BEACH_MAX_M) & (arr < _GRASS_MAX_M)
    rock = (arr >= _GRASS_MAX_M) & (arr < snow_line)
    snow = arr >= snow_line

    apply(water, _WATER_RAMP, (arr - vmin) / max(1.0, 0.0 - vmin))
    apply(beach, _BEACH_RAMP, arr / _BEACH_MAX_M)
    apply(grass, _GRASS_RAMP, (arr - _BEACH_MAX_M) / (_GRASS_MAX_M - _BEACH_MAX_M))
    apply(rock, _ROCK_RAMP, (arr - _GRASS_MAX_M) / max(1.0, snow_line - _GRASS_MAX_M))
    apply(snow, _SNOW_RAMP, (arr - snow_line) / max(1.0, vmax - snow_line))
    return np.clip(rgb, 0, 255).astype(np.uint8)


def _array_to_data_url(arr):
    """Encode an image array as a base64 ``data:image/png`` URL."""
    image = Image.fromarray(np.asarray(arr, dtype=np.uint8))
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"


def _grid_to_display(grid):
    """Upscale a conditioning grid 18x per cell and colourise it for preview."""
    display = np.kron(grid, np.ones((18, 18), dtype=np.float32))
    return _colorize_elevation(display)


def _decode_sketch_data_url(sketch_data_url):
    """Decode an image data URL into a greyscale (``"L"``) PIL image.

    Returns ``None`` when the value is missing, is not a string, does not
    start with ``data:image``, or cannot be decoded.
    """
    # The value comes from the client; reject non-strings rather than let
    # ``.startswith`` raise AttributeError.
    if not isinstance(sketch_data_url, str):
        return None
    if not sketch_data_url.startswith("data:image"):
        return None
    try:
        _, encoded = sketch_data_url.split(",", 1)
        raw = base64.b64decode(encoded)
        return Image.open(io.BytesIO(raw)).convert("L")
    except Exception as exc:
        # Deliberately broad: any failure on untrusted image bytes is logged
        # and reported as "no sketch" rather than failing the request.
        print(f"Could not decode sketch data URL: {exc}")
        return None


def _sketch_to_conditioning(sketch_data_url):
    """Convert a sketch data URL into a ``TOTAL x TOTAL`` float32 elevation grid.

    Grey levels 0..255 are mapped linearly to -1500..3500 and the image is
    bilinearly resized to the conditioning grid. Returns ``None`` when the
    sketch cannot be decoded.
    """
    sketch = _decode_sketch_data_url(sketch_data_url)
    if sketch is None:
        return None
    arr = np.array(sketch, dtype=np.float32)
    normalized = arr / 255.0
    elev = (normalized - 0.3) * 5000.0
    pil_img = Image.fromarray(elev.astype(np.float32), mode="F")
    pil_resized = pil_img.resize((TOTAL, TOTAL), Image.BILINEAR)
    return np.array(pil_resized, dtype=np.float32)


def _relief_preview(elev):
    """Render a 512x512 relief image of ``elev`` as a PNG data URL.

    The output of ``get_relief_map`` is clipped to [0, 1] and scaled to 8 bit;
    a 2-D result is repeated to three channels and a fourth channel is dropped.
    """
    relief = get_relief_map(elev, None, None, None, resolution=PIXEL_RESOLUTION_M)
    relief_img = (np.clip(relief, 0, 1) * 255).astype(np.uint8)
    if relief_img.ndim == 2:
        relief_img = np.repeat(relief_img[:, :, None], 3, axis=2)
    elif relief_img.shape[2] == 4:
        relief_img = relief_img[:, :, :3]
    pil = Image.fromarray(relief_img).resize((512, 512), Image.BILINEAR)
    return _array_to_data_url(np.array(pil))


def _elevation_preview(elev):
    """Render a 512x512 colourised elevation image as a PNG data URL."""
    colored = _colorize_elevation(elev)
    pil = Image.fromarray(colored).resize((512, 512), Image.BILINEAR)
    return _array_to_data_url(np.array(pil))


def _mesh_payload(elev):
    """Downsample ``elev`` into a JSON-serialisable mesh description.

    At most ``MESH_MAX_SAMPLES`` evenly spaced rows and columns are kept and
    values are rounded to one decimal place.
    """
    h, w = elev.shape
    rows = min(MESH_MAX_SAMPLES, h)
    cols = min(MESH_MAX_SAMPLES, w)
    row_idx = np.linspace(0, h - 1, rows).astype(np.int32)
    col_idx = np.linspace(0, w - 1, cols).astype(np.int32)
    sampled = elev[np.ix_(row_idx, col_idx)].astype(np.float32)
    sampled = np.round(sampled, 1)
    return {
        "rows": int(sampled.shape[0]),
        "cols": int(sampled.shape[1]),
        "values": sampled.tolist(),
        "size_x_m": float((w - 1) * PIXEL_RESOLUTION_M),
        "size_y_m": float((h - 1) * PIXEL_RESOLUTION_M),
        "resolution_m": PIXEL_RESOLUTION_M,
    }


def _write_heightmap(elev):
    """Save ``elev`` as a 16-bit PNG in ``OUTPUT_DIR`` and return its URL path.

    Values are rounded to the nearest integer and clipped to 0..65535, so
    negative elevations are stored as 0.
    """
    path = OUTPUT_DIR / f"heightmap-{uuid.uuid4().hex}.png"
    # Round before the cast: ``astype`` alone truncates toward zero, which
    # biases every stored height downward.
    elev_clipped = np.clip(np.rint(elev), 0, 65535).astype(np.uint16)
    Image.fromarray(elev_clipped).save(path)
    return f"/heightmaps/{path.name}"


def _terrain_stats(elev, elapsed_s):
    """Summarise a generated tile: elevation range, mean, extent and timing."""
    h, w = elev.shape
    width_km = w * PIXEL_RESOLUTION_M / 1000.0
    height_km = h * PIXEL_RESOLUTION_M / 1000.0
    area_km2 = width_km * height_km
    return {
        "min_m": round(float(np.min(elev)), 1),
        "max_m": round(float(np.max(elev)), 1),
        "mean_m": round(float(np.mean(elev)), 1),
        "area_km2": round(area_km2, 1),
        "width_km": round(width_km, 2),
        "height_km": round(height_km, 2),
        "resolution_m": PIXEL_RESOLUTION_M,
        "elapsed_s": round(float(elapsed_s), 2),
    }
def _condition_grid(preset_choice, input_mode, sketch_data_url):
    """Pick the conditioning grid for a request.

    Returns ``(grid, source_label)``. A decodable custom sketch wins when
    ``input_mode`` is "Custom Sketch"; otherwise the named preset is built,
    falling back to "Mountains" for an unknown name. The label always names
    the source that was actually used.
    """
    if input_mode == "Custom Sketch":
        cond_elev = _sketch_to_conditioning(sketch_data_url)
        if cond_elev is not None:
            return cond_elev, "Custom Sketch"
    # Report the preset really used, not an unknown requested name.
    preset_name = preset_choice if preset_choice in PRESETS else "Mountains"
    return PRESETS[preset_name](), preset_name


@spaces.GPU(duration=10)
def _run_generation(preset_choice, seed, input_mode, sketch_data_url):
    """Generate one terrain tile and return the JSON-serialisable result.

    The shared pipeline is moved to CUDA for the call and always moved back
    to the CPU afterwards. A ``seed`` of ``None`` keeps the pipeline's
    current seed.
    """
    started = time.perf_counter()
    world = _get_world()
    try:
        # Inside the ``try`` so a failed move is still followed by the
        # ``finally`` that puts the pipeline back on the CPU.
        world.to("cuda")

        # Resolve the seed once; ``int(None)`` in the log line or the result
        # would otherwise raise TypeError.
        seed_value = world.seed if seed is None else int(seed)
        if seed_value != world.seed:
            world.change_seed(seed_value)

        grid, conditioning_source = _condition_grid(
            preset_choice, input_mode, sketch_data_url
        )
        world.set_custom_conditioning_import(0, grid, 0, 0, default_value=-1000.0)
        world.set_cond_snr([0.5, 0.5, 0.5, 0.5, 0.5])

        # Pixel bounds of the interior (un-padded) region, 256 px per cell.
        pi1 = PADDING * 256
        pi2 = (PADDING + GRID_SIZE) * 256
        with world:
            result = world.get(pi1, pi1, pi2, pi2, with_climate=False)
        elev = result["elev"].cpu().numpy()

        elapsed_s = time.perf_counter() - started
        print(
            f"Generated terrain in {elapsed_s:.2f}s, "
            f"seed={seed_value}, source={conditioning_source}"
        )
        return {
            "ok": True,
            "seed": seed_value,
            "preset": preset_choice,
            "conditioning_source": conditioning_source,
            "mesh": _mesh_payload(elev),
            "stats": _terrain_stats(elev, elapsed_s),
            "relief_image": _relief_preview(elev),
            "elevation_image": _elevation_preview(elev),
            "conditioning_image": _array_to_data_url(_grid_to_display(grid)),
            "heightmap_url": _write_heightmap(elev),
        }
    finally:
        world.to("cpu")


def _preset_payload():
    """Return name, description, preview image and range for every preset."""
    payload = []
    for name, fn in PRESETS.items():
        grid = fn()
        payload.append(
            {
                "name": name,
                "description": PRESET_DESCRIPTIONS[name],
                "preview": _array_to_data_url(_grid_to_display(grid)),
                "min_m": round(float(grid.min()), 1),
                "max_m": round(float(grid.max()), 1),
            }
        )
    return payload


def _model_info_payload():
    """Return the static model/paper metadata shown by the front end."""
    return {
        "model_id": MODEL_ID,
        "model_url": MODEL_URL,
        "paper_url": PAPER_URL,
        "paper_title": "InfiniteDiffusion: Bridging Learned Fidelity and Procedural Utility for Open-World Terrain Generation",
        "author": "Alexander Goslin",
        "resolution": "30 m/pixel checkpoint",
        "pipeline": [
            {
                "name": "Coarse planetary model",
                "text": "A rough procedural or user-provided layout is refined into continental-scale elevation and climate context.",
            },
            {
                "name": "Core latent model",
                "text": "A base latent diffusion stage turns coarse context into coherent low-frequency terrain structure and residual latents.",
            },
            {
                "name": "Consistency decoder",
                "text": "The decoder expands latents into high-frequency elevation detail for the final heightmap.",
            },
        ],
        "why_diffusion": "Perlin-style noise is fast, seed-consistent, and unbounded, but tends to produce smooth stationary texture. Terrain Diffusion keeps procedural utility while learning ridges, basins, coastlines, and multi-scale structure from global elevation data.",
        "infinite": "The paper's InfiniteDiffusion algorithm queries finite regions of an implicit unbounded world with seed consistency and constant-time random access for fixed-size regions.",
    }


# Page served at "/".
# NOTE(review): the markup appears to have been stripped from this literal in
# the copy under review; the text is reproduced unchanged — confirm upstream.
INDEX_HTML = r""" Terrain Diffusion Demo
Generating terrain ZeroGPU is running the diffusion pipeline.
Elevation
Loading interface

Terrain Diffusion Demo

xandergos/terrain-diffusion-30m

InfiniteDiffusion: Bridging Learned Fidelity and Procedural Utility for Open-World Terrain Generation

"""

app = gr.Server(title="Terrain Diffusion Demo")
demo = app


@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page front end."""
    return HTMLResponse(INDEX_HTML)


@app.get("/healthz")
def healthz():
    """Liveness endpoint; reports the configured model id."""
    return JSONResponse({"ok": True, "model": MODEL_ID})


@app.get("/heightmaps/{filename}")
def heightmap(filename: str):
    """Serve a previously generated heightmap PNG from ``OUTPUT_DIR``.

    Only the final path component of ``filename`` is used. Anything that is
    not an existing regular file yields a 404 JSON response.
    """
    safe_name = Path(filename).name
    path = OUTPUT_DIR / safe_name
    # ``is_file`` rather than ``exists``: ``Path("..").name`` is ``".."``,
    # which resolves to an existing directory and cannot be sent as a file.
    if not path.is_file():
        return JSONResponse({"error": "heightmap not found"}, status_code=404)
    return FileResponse(path, media_type="image/png", filename=safe_name)


@app.api(name="get_presets", queue=False)
def get_presets() -> dict:
    """API: list the available conditioning presets with previews."""
    return {"presets": _preset_payload()}


@app.api(name="get_model_info", queue=False)
def get_model_info() -> dict:
    """API: return static model and paper metadata."""
    return _model_info_payload()


@app.api(name="generate_terrain", concurrency_limit=1, time_limit=180)
def generate_terrain(preset_choice: str, seed: int, input_mode: str, sketch_data_url: str) -> dict:
    """API: run one terrain generation and return the result payload."""
    return _run_generation(preset_choice, seed, input_mode, sketch_data_url)


if __name__ == "__main__":
    # Load the model before launching so the first request does not pay for it.
    _get_world()
    app.launch(server_name="0.0.0.0")