Mike0021's picture
Tune ZeroGPU duration after live measurements
c33f538 verified
Raw
History Blame Contribute Delete
48.1 kB
import os
# Point caches and temp dirs at /tmp before the library imports that follow. These
# variables are presumably read when those libraries are imported, so the order here
# matters. setdefault() keeps any value already present in the environment.
for _env_name, _env_default in (
    ("HF_HOME", "/tmp/hf_cache"),
    ("HF_MODULES_CACHE", "/tmp/hf_modules"),
    ("MPLCONFIGDIR", "/tmp/matplotlib"),
    ("GRADIO_TEMP_DIR", "/tmp/gradio"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from starlette.responses import FileResponse, HTMLResponse, JSONResponse
import base64
import io
import math
import time
import uuid
from pathlib import Path
from terrain_diffusion.inference.relief_map import get_relief_map
from terrain_diffusion.inference.world_pipeline import WorldPipeline
# Checkpoint and paper links surfaced in the UI.
MODEL_ID = "xandergos/terrain-diffusion-30m"
PAPER_URL = "https://arxiv.org/abs/2512.08309"
MODEL_URL = "https://huggingface.co/xandergos/terrain-diffusion-30m"
# Scratch directory for generated heightmap PNGs; created at import time.
OUTPUT_DIR = Path("/tmp") / "terrain_outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Metres per output pixel (the "30m" checkpoint).
PIXEL_RESOLUTION_M = 30
# Conditioning-grid layout: a GRID_SIZE x GRID_SIZE interior surrounded by PADDING
# cells on each side, giving a TOTAL x TOTAL grid.
PADDING = 4
GRID_SIZE = 4
TOTAL = 2 * PADDING + GRID_SIZE
# Upper bound on mesh samples per axis sent to the WebGL viewer.
MESH_MAX_SAMPLES = 160
_world = None  # lazily created WorldPipeline, shared by every call in this process
def _get_world():
    """Return the shared ``WorldPipeline``, loading it on first use.

    The first call builds the pipeline from ``MODEL_ID`` (seed 42, no torch.compile,
    "direct" caching), parks it on the CPU and binds it with ``hdf5_file=None``.
    Later calls return the cached instance unchanged.
    """
    global _world
    if _world is None:
        print("Loading terrain diffusion model...")
        load_started = time.time()
        # NOTE(review): the global is assigned before .to()/.bind(). If either raises,
        # the partially initialised pipeline stays cached for later calls.
        _world = WorldPipeline.from_pretrained(
            MODEL_ID,
            seed=42,
            latents_batch_size=[1, 2, 4, 8, 16],
            log_mode="info",
            torch_compile=False,
            dtype=None,
            caching_strategy="direct",
            cache_limit=100 * 1024 * 1024,  # presumably bytes (100 MiB); confirm against WorldPipeline
        )
        _world.to("cpu")
        _world.bind(hdf5_file=None)
        print(f"Model loaded in {time.time() - load_started:.1f}s, seed={_world.seed}")
    return _world
@spaces.GPU(duration=1)
def _zerogpu_probe():
    """Minimal ZeroGPU-decorated call (1 s budget) that only reports ``"ready"``.

    Presumably used to check or warm up GPU allocation. Its caller is not in the
    reviewed part of the file.
    """
    status = "ready"
    return status
def _empty_grid(default_val=-1000.0):
    """Return a ``(TOTAL, TOTAL)`` float32 conditioning grid with every cell set to ``default_val`` metres."""
    grid = np.empty((TOTAL, TOTAL), dtype=np.float32)
    grid.fill(default_val)
    return grid
def _mountains_grid():
    """Conditioning grid with a cone-shaped peak and a lower foothill ring.

    The radius ``r`` of each cell is measured from cell ``TOTAL // 2`` in units of
    ``GRID_SIZE / 2`` cells. Elevation is:
    - ``2500 * (1 - r)`` for ``r < 1``;
    - ``500 * (1.5 - r) / 0.5`` for ``1 <= r < 1.5``;
    - the -1000 m default elsewhere.

    NOTE(review): the two bands do not meet. The peak tends to 0 m as ``r -> 1``, but
    the ring starts at 500 m. Also, with the current constants the interior spans
    cells 4..7 (centre 5.5), so the peak sits half a cell off the interior's centre.
    """
    grid = _empty_grid()
    centre = TOTAL // 2
    offsets = (np.arange(TOTAL) - centre) / (GRID_SIZE / 2)
    di, dj = np.meshgrid(offsets, offsets, indexing="ij")
    radius = np.sqrt(di**2 + dj**2)
    peak = radius < 1.0
    foothills = ~peak & (radius < 1.5)
    grid[peak] = 2500.0 * (1.0 - radius[peak])
    grid[foothills] = 500.0 * (1.5 - radius[foothills]) / 0.5
    return grid
def _plains_grid():
    """Conditioning grid with gentle 50-350 m undulations over the interior cells only.

    Each interior cell gets ``200 + 150 * sin(3 * di) * cos(3 * dj)``, where ``di`` and
    ``dj`` are offsets from cell ``TOTAL // 2`` in units of ``GRID_SIZE / 2``. The
    padding ring keeps the -1000 m default.
    """
    grid = _empty_grid()
    centre = TOTAL // 2
    interior = range(PADDING, PADDING + GRID_SIZE)
    offsets = [(k - centre) / (GRID_SIZE / 2) for k in interior]
    for row, di in zip(interior, offsets):
        for col, dj in zip(interior, offsets):
            grid[row, col] = 200.0 + 150.0 * np.sin(di * 3) * np.cos(dj * 3)
    return grid
def _islands_grid():
    """Conditioning grid for a small central island in deep (-2000 m) ocean.

    The radius ``r`` of each cell is measured from cell ``TOTAL // 2`` in units of
    ``GRID_SIZE / 2`` cells. Elevation is:
    - ``800 * (1 - r / 0.5)`` for ``r < 0.5``;
    - ``100 * (0.8 - r) / 0.3`` for ``0.5 <= r < 0.8``;
    - -2000 m elsewhere.
    """
    grid = _empty_grid(-2000.0)
    centre = TOTAL // 2
    offsets = (np.arange(TOTAL) - centre) / (GRID_SIZE / 2)
    di, dj = np.meshgrid(offsets, offsets, indexing="ij")
    radius = np.sqrt(di**2 + dj**2)
    core = radius < 0.5
    shelf = ~core & (radius < 0.8)
    grid[core] = 800.0 * (1.0 - radius[core] / 0.5)
    grid[shelf] = 100.0 * (0.8 - radius[shelf]) / 0.3
    return grid
def _canyon_grid():
    """Conditioning grid with a north-south canyon through the interior columns.

    Every row of an interior column gets the same value. The value depends on the
    column's offset ``d`` from ``TOTAL // 2`` (in units of ``GRID_SIZE / 2``):
    - -300 m for ``d < 0.3``;
    - 1500 m for ``0.3 <= d < 0.5``;
    - 800 m otherwise.

    Padding columns keep the -1000 m default.

    NOTE(review): with the current constants the interior offsets are 1.0, 0.5, 0.0
    and 0.5, so the 1500 m "shoulder" band is never reached.
    """
    grid = _empty_grid()
    centre = TOTAL // 2
    for col in range(PADDING, PADDING + GRID_SIZE):
        offset = abs(col - centre) / (GRID_SIZE / 2)
        if offset < 0.3:
            level = -300.0
        elif offset < 0.5:
            level = 1500.0
        else:
            level = 800.0
        grid[:, col] = level
    return grid
# Preset name -> zero-argument builder returning a (TOTAL, TOTAL) float32 conditioning grid.
PRESETS = dict(
    Mountains=_mountains_grid,
    Plains=_plains_grid,
    Islands=_islands_grid,
    Canyon=_canyon_grid,
)
# One-line blurbs shown with each preset; keys mirror PRESETS.
PRESET_DESCRIPTIONS = dict(
    Mountains="High central relief with foothill context.",
    Plains="Low rolling elevation with broad interior variation.",
    Islands="Ocean-biased layout with compact emergent land.",
    Canyon="Incised central valley with raised shoulders.",
)
def _lerp_color(a, b, t):
    """Blend RGB triple ``a`` toward ``b`` by ``t``, clamped to [0, 1]; returns a float32 array."""
    weight = float(np.clip(t, 0.0, 1.0))
    start = np.array(a, dtype=np.float32)
    end = np.array(b, dtype=np.float32)
    return start * (1 - weight) + end * weight
def _terrain_color(height, vmin, vmax):
    """Return the RGB colour (float32 array) for a single elevation ``height`` in metres.

    This is the scalar counterpart of ``_colorize_elevation`` and uses the same bands:
    water below 0, beach below 70 m, grass below 950 m, rock below the snow line, and
    snow above it. The snow line is the larger of 1850 m and 72% of the way from
    ``vmin`` to ``vmax``. NOTE(review): not called in the reviewed part of the file.
    """
    snow_line = max(1850.0, vmin + (vmax - vmin) * 0.72)
    if height < 0:
        low, high = (18, 53, 93), (47, 112, 151)
        t = np.clip((height - vmin) / max(1.0, 0.0 - vmin), 0, 1)
    elif height < 70:
        low, high, t = (182, 158, 105), (219, 197, 137), height / 70.0
    elif height < 950:
        low, high, t = (47, 116, 71), (126, 144, 79), (height - 70) / 880.0
    elif height < snow_line:
        low, high = (107, 108, 99), (160, 158, 147)
        t = (height - 950) / max(1.0, snow_line - 950)
    else:
        low, high = (218, 224, 219), (250, 252, 247)
        t = (height - snow_line) / max(1.0, vmax - snow_line)
    return _lerp_color(low, high, t)
def _colorize_elevation(elev):
    """Map an elevation array (metres) to a uint8 RGB image of shape ``elev.shape + (3,)``.

    Bands and palette: water < 0 <= beach < 70 <= grass < 950 <= rock < snow line <= snow.
    The snow line is the larger of 1850 m and 72% of the way from the min to the max.
    Inside each band the colour is blended between two endpoints. NaN cells match no
    band and stay black.
    """
    heights = np.asarray(elev, dtype=np.float32)
    lowest = float(np.nanmin(heights))
    highest = float(np.nanmax(heights))
    snow_line = max(1850.0, lowest + (highest - lowest) * 0.72)
    out = np.zeros(heights.shape + (3,), dtype=np.float32)
    bands = [
        (heights < 0, (18, 53, 93), (47, 112, 151),
         (heights - lowest) / max(1.0, 0.0 - lowest)),
        ((heights >= 0) & (heights < 70), (182, 158, 105), (219, 197, 137),
         heights / 70.0),
        ((heights >= 70) & (heights < 950), (47, 116, 71), (126, 144, 79),
         (heights - 70) / 880.0),
        ((heights >= 950) & (heights < snow_line), (107, 108, 99), (160, 158, 147),
         (heights - 950) / max(1.0, snow_line - 950)),
        (heights >= snow_line, (218, 224, 219), (250, 252, 247),
         (heights - snow_line) / max(1.0, highest - snow_line)),
    ]
    for mask, low, high, t in bands:
        if not np.any(mask):
            continue
        # Trailing axis lets the per-pixel weight broadcast over the 3 colour channels.
        weight = np.clip(t, 0.0, 1.0)[..., None][mask]
        out[mask] = (np.array(low, dtype=np.float32) * (1.0 - weight)
                     + np.array(high, dtype=np.float32) * weight)
    return np.clip(out, 0, 255).astype(np.uint8)
def _array_to_data_url(arr):
    """Encode ``arr`` (cast to uint8) as a PNG and return it as a base64 ``data:`` URL."""
    with io.BytesIO() as buffer:
        Image.fromarray(np.asarray(arr, dtype=np.uint8)).save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")
def _grid_to_display(grid):
    """Blow each cell of a 2-D conditioning grid up to an 18x18 pixel block and colourise it."""
    cell_px = 18
    blocky = np.repeat(np.repeat(grid, cell_px, axis=0), cell_px, axis=1)
    return _colorize_elevation(blocky)
def _decode_sketch_data_url(sketch_data_url):
    """Decode a ``data:image/...;base64,...`` URL into a grayscale ("L") PIL image.

    Returns ``None`` in three cases: the input is not a string, it does not start with
    ``data:image``, or decoding fails. Decoding failures are logged with ``print``
    rather than raised, because this input comes straight from the browser.
    """
    # Only a string can be a data URL. Other truthy values (numbers, lists, dicts
    # from a malformed API call) would otherwise crash on .startswith().
    if not isinstance(sketch_data_url, str) or not sketch_data_url.startswith("data:image"):
        return None
    try:
        _, encoded = sketch_data_url.split(",", 1)
        raw = base64.b64decode(encoded)
        return Image.open(io.BytesIO(raw)).convert("L")
    except Exception as exc:
        # Deliberate best-effort boundary: any bad payload falls back to "no sketch".
        print(f"Could not decode sketch data URL: {exc}")
        return None
def _sketch_to_conditioning(sketch_data_url):
    """Turn a sketch data URL into a ``(TOTAL, TOTAL)`` float32 elevation grid in metres.

    Grey level ``g`` maps linearly to ``(g / 255 - 0.3) * 5000`` m, so black is -1500 m
    and white is 3500 m. The full-resolution map is then bilinearly resized to the
    conditioning grid. Returns ``None`` when the sketch cannot be decoded.
    """
    sketch = _decode_sketch_data_url(sketch_data_url)
    if sketch is None:
        return None
    normalized = np.array(sketch, dtype=np.float32) / 255.0
    elev = ((normalized - 0.3) * 5000.0).astype(np.float32)
    # Pillow infers mode "F" from a 2-D float32 array. Passing mode= explicitly to
    # fromarray() is deprecated in recent Pillow releases.
    pil_resized = Image.fromarray(elev).resize((TOTAL, TOTAL), Image.BILINEAR)
    return np.array(pil_resized, dtype=np.float32)
def _relief_preview(elev):
    """Render a 512x512 shaded-relief PNG data URL for ``elev``.

    The ``get_relief_map`` output is clipped to [0, 1] before scaling to 0-255, so it is
    presumably expected in that range. A single-channel result is expanded to RGB and
    an alpha channel is dropped.
    """
    shaded = get_relief_map(elev, None, None, None, resolution=PIXEL_RESOLUTION_M)
    pixels = (np.clip(shaded, 0, 1) * 255).astype(np.uint8)
    if pixels.ndim == 2:
        pixels = np.stack([pixels, pixels, pixels], axis=-1)
    elif pixels.shape[2] == 4:
        pixels = pixels[..., :3]
    resized = Image.fromarray(pixels).resize((512, 512), Image.BILINEAR)
    return _array_to_data_url(np.array(resized))
def _elevation_preview(elev):
    """Render a 512x512 colour-banded elevation PNG data URL for ``elev``."""
    preview_size = (512, 512)
    tinted = Image.fromarray(_colorize_elevation(elev))
    return _array_to_data_url(np.array(tinted.resize(preview_size, Image.BILINEAR)))
def _mesh_payload(elev, *, max_samples=None, resolution_m=None):
    """Downsample a 2-D heightmap into the JSON mesh description used by the WebGL viewer.

    Args:
        elev: 2-D array of elevations in metres.
        max_samples: cap on samples per axis (defaults to ``MESH_MAX_SAMPLES``).
        resolution_m: metres per source pixel (defaults to ``PIXEL_RESOLUTION_M``).

    Returns:
        A dict with the following keys:
        - ``rows``, ``cols``: the sampled grid size;
        - ``values``: nested lists of heights rounded to 0.1 m;
        - ``size_x_m``, ``size_y_m``: the distance spanned between the first and last
          source pixel on each axis;
        - ``resolution_m``.
    """
    if max_samples is None:
        max_samples = MESH_MAX_SAMPLES
    if resolution_m is None:
        resolution_m = PIXEL_RESOLUTION_M
    elev = np.asarray(elev)
    h, w = elev.shape
    rows = min(max_samples, h)
    cols = min(max_samples, w)
    # Evenly spaced indices that always include the first and last row/column.
    row_idx = np.linspace(0, h - 1, rows).astype(np.int32)
    col_idx = np.linspace(0, w - 1, cols).astype(np.int32)
    sampled = elev[np.ix_(row_idx, col_idx)].astype(np.float64)
    # Round in float64. A float32 value such as 123.4 widens to 123.40000152587891
    # in tolist(), which defeats the rounding and bloats the JSON payload.
    sampled = np.round(sampled, 1)
    return {
        "rows": int(sampled.shape[0]),
        "cols": int(sampled.shape[1]),
        "values": sampled.tolist(),
        "size_x_m": float((w - 1) * resolution_m),
        "size_y_m": float((h - 1) * resolution_m),
        "resolution_m": resolution_m,
    }
def _write_heightmap(elev):
    """Save ``elev`` as a 16-bit grayscale PNG under ``OUTPUT_DIR`` and return its URL path.

    Each pixel is the elevation rounded to the nearest whole metre and clipped to the
    uint16 range [0, 65535], so terrain below 0 m (sea floor) is written as 0. NaN is
    written as 0 and infinities are clipped. The returned ``/heightmaps/<name>`` path
    is presumably served by a route defined elsewhere in the file.
    NOTE(review): files are never deleted here, so OUTPUT_DIR grows with every generation.
    """
    path = OUTPUT_DIR / f"heightmap-{uuid.uuid4().hex}.png"
    heights = np.nan_to_num(np.asarray(elev, dtype=np.float64), nan=0.0)
    # Round to nearest before the integer cast; astype() alone truncates toward
    # zero and biases every sample downward.
    quantized = np.rint(np.clip(heights, 0, 65535)).astype(np.uint16)
    Image.fromarray(quantized).save(path)
    return f"/heightmaps/{path.name}"
def _terrain_stats(elev, elapsed_s):
    """Summarise a generated heightmap for the UI.

    Returns the following keys:
    - ``min_m``, ``max_m``, ``mean_m``: elevation in metres, 0.1 m precision;
    - ``width_km``, ``height_km``: footprint as pixel count x ``PIXEL_RESOLUTION_M``;
    - ``area_km2``: their product;
    - ``resolution_m``;
    - ``elapsed_s``: the generation time.
    """
    n_rows, n_cols = elev.shape
    width_km = n_cols * PIXEL_RESOLUTION_M / 1000.0
    height_km = n_rows * PIXEL_RESOLUTION_M / 1000.0
    lowest, highest, average = (float(reduce_fn(elev)) for reduce_fn in (np.min, np.max, np.mean))
    return {
        "min_m": round(lowest, 1),
        "max_m": round(highest, 1),
        "mean_m": round(average, 1),
        "area_km2": round(width_km * height_km, 1),
        "width_km": round(width_km, 2),
        "height_km": round(height_km, 2),
        "resolution_m": PIXEL_RESOLUTION_M,
        "elapsed_s": round(float(elapsed_s), 2),
    }
def _condition_grid(preset_choice, input_mode, sketch_data_url):
    """Pick the conditioning grid for a generation request.

    Returns ``(grid, source_label)``. In "Custom Sketch" mode a decodable sketch wins
    and the label is ``"Custom Sketch"``. Otherwise the named preset is used. An
    unknown preset name falls back to "Mountains", and the label names the preset
    that was actually used.
    """
    if input_mode == "Custom Sketch":
        cond_elev = _sketch_to_conditioning(sketch_data_url)
        if cond_elev is not None:
            return cond_elev, "Custom Sketch"
    # Resolve the name first so the reported source matches the grid that was built.
    preset_name = preset_choice if preset_choice in PRESETS else "Mountains"
    return PRESETS[preset_name](), preset_name
@spaces.GPU(duration=10)
def _run_generation(preset_choice, seed, input_mode, sketch_data_url):
    """Generate one terrain tile on the GPU and return the JSON-ready payload for the client.

    The function:
    - moves the shared pipeline to CUDA for the call, and back to CPU in ``finally``;
    - applies the requested seed (``seed=None`` keeps the pipeline's current seed);
    - conditions on a preset or sketch grid;
    - samples the interior (non-padding) region.

    The payload carries the following:
    - mesh data;
    - stats;
    - preview images;
    - the conditioning image;
    - a heightmap download URL.
    """
    started = time.perf_counter()
    world = _get_world()
    world.to("cuda")
    try:
        # Parse the seed once. The log line and payload below need a concrete value
        # even when the caller passed None.
        requested_seed = None if seed is None else int(seed)
        if requested_seed is not None and requested_seed != world.seed:
            world.change_seed(requested_seed)
        seed_value = int(world.seed) if requested_seed is None else requested_seed
        grid, conditioning_source = _condition_grid(preset_choice, input_mode, sketch_data_url)
        world.set_custom_conditioning_import(0, grid, 0, 0, default_value=-1000.0)
        world.set_cond_snr([0.5, 0.5, 0.5, 0.5, 0.5])
        # Presumably each conditioning cell spans 256 output pixels. This requests only
        # the interior GRID_SIZE x GRID_SIZE cells and skips the padding.
        pi1 = PADDING * 256
        pi2 = (PADDING + GRID_SIZE) * 256
        with world:
            result = world.get(pi1, pi1, pi2, pi2, with_climate=False)
        elev = result["elev"].cpu().numpy()
        elapsed_s = time.perf_counter() - started
        print(f"Generated terrain in {elapsed_s:.2f}s, seed={seed_value}, source={conditioning_source}")
        return {
            "ok": True,
            "seed": seed_value,
            "preset": preset_choice,
            "conditioning_source": conditioning_source,
            "mesh": _mesh_payload(elev),
            "stats": _terrain_stats(elev, elapsed_s),
            "relief_image": _relief_preview(elev),
            "elevation_image": _elevation_preview(elev),
            "conditioning_image": _array_to_data_url(_grid_to_display(grid)),
            "heightmap_url": _write_heightmap(elev),
        }
    finally:
        world.to("cpu")
def _preset_payload():
    """List every preset, in ``PRESETS`` order, with its blurb, preview image and elevation range."""
    def summarize(name, build):
        grid = build()
        return {
            "name": name,
            "description": PRESET_DESCRIPTIONS[name],
            "preview": _array_to_data_url(_grid_to_display(grid)),
            "min_m": round(float(grid.min()), 1),
            "max_m": round(float(grid.max()), 1),
        }
    return [summarize(name, build) for name, build in PRESETS.items()]
def _model_info_payload():
    """Static model and paper metadata rendered in the UI's About panel."""
    stages = [
        ("Coarse planetary model",
         "A rough procedural or user-provided layout is refined into continental-scale elevation and climate context."),
        ("Core latent model",
         "A base latent diffusion stage turns coarse context into coherent low-frequency terrain structure and residual latents."),
        ("Consistency decoder",
         "The decoder expands latents into high-frequency elevation detail for the final heightmap."),
    ]
    why_diffusion = (
        "Perlin-style noise is fast, seed-consistent, and unbounded, but tends to produce smooth stationary texture. "
        "Terrain Diffusion keeps procedural utility while learning ridges, basins, coastlines, and multi-scale structure from global elevation data."
    )
    infinite = (
        "The paper's InfiniteDiffusion algorithm queries finite regions of an implicit unbounded world "
        "with seed consistency and constant-time random access for fixed-size regions."
    )
    return {
        "model_id": MODEL_ID,
        "model_url": MODEL_URL,
        "paper_url": PAPER_URL,
        "paper_title": "InfiniteDiffusion: Bridging Learned Fidelity and Procedural Utility for Open-World Terrain Generation",
        "author": "Alexander Goslin",
        "resolution": "30 m/pixel checkpoint",
        "pipeline": [{"name": stage_name, "text": stage_text} for stage_name, stage_text in stages],
        "why_diffusion": why_diffusion,
        "infinite": infinite,
    }
INDEX_HTML = r"""
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Terrain Diffusion Demo</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
<style>
:root {
color-scheme: dark;
--bg: #0d1112;
--panel: #151b1d;
--panel-2: #1c2425;
--line: #334044;
--line-soft: #263135;
--text: #eef2ed;
--muted: #a9b4ae;
--green: #6fb36f;
--cyan: #70b6c7;
--amber: #d2a85f;
--rust: #c56d4d;
--blue: #3e78a8;
}
* {
box-sizing: border-box;
}
body {
margin: 0;
min-height: 100vh;
overflow: hidden;
background: var(--bg);
color: var(--text);
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
letter-spacing: 0;
}
button, input {
font: inherit;
}
a {
color: var(--cyan);
}
.app-shell {
display: grid;
grid-template-columns: minmax(0, 1fr) 390px;
height: 100vh;
width: 100vw;
}
.viewport {
position: relative;
min-width: 0;
background: #0a0e0f;
overflow: hidden;
}
#sceneHost {
position: absolute;
inset: 0;
}
.scene-toolbar {
position: absolute;
left: 18px;
right: 18px;
bottom: 18px;
display: flex;
align-items: center;
gap: 10px;
flex-wrap: wrap;
padding: 10px;
border: 1px solid rgba(130, 151, 151, 0.28);
background: rgba(15, 20, 21, 0.74);
backdrop-filter: blur(12px);
border-radius: 8px;
}
.toolbar-group {
display: flex;
align-items: center;
gap: 8px;
min-width: 0;
}
.toolbar-label {
color: var(--muted);
font-size: 12px;
white-space: nowrap;
}
.range {
width: 140px;
accent-color: var(--green);
}
.status-pill {
min-height: 36px;
display: inline-flex;
align-items: center;
gap: 8px;
padding: 0 12px;
border: 1px solid var(--line-soft);
border-radius: 8px;
color: var(--muted);
background: rgba(12, 16, 17, 0.72);
font-size: 13px;
margin-left: auto;
}
.dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: var(--amber);
box-shadow: 0 0 14px rgba(210, 168, 95, 0.55);
flex: 0 0 auto;
}
.dot.ready {
background: var(--green);
box-shadow: 0 0 14px rgba(111, 179, 111, 0.55);
}
.dot.error {
background: var(--rust);
box-shadow: 0 0 14px rgba(197, 109, 77, 0.55);
}
.sidebar {
min-width: 0;
height: 100vh;
overflow: auto;
border-left: 1px solid var(--line);
background: var(--panel);
}
.side-inner {
display: flex;
flex-direction: column;
gap: 18px;
padding: 20px;
}
.brand {
display: flex;
flex-direction: column;
gap: 6px;
border-bottom: 1px solid var(--line-soft);
padding-bottom: 16px;
}
.eyebrow {
margin: 0;
color: var(--green);
font-size: 12px;
font-weight: 700;
text-transform: uppercase;
}
h1, h2, h3, p {
margin: 0;
}
.brand h1 {
font-size: 22px;
line-height: 1.15;
font-weight: 760;
}
.subcopy {
color: var(--muted);
font-size: 13px;
line-height: 1.45;
}
.panel {
display: flex;
flex-direction: column;
gap: 12px;
}
.panel h2 {
font-size: 13px;
text-transform: uppercase;
color: #ccd8d0;
}
.segmented {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 6px;
padding: 4px;
border: 1px solid var(--line-soft);
border-radius: 8px;
background: #101617;
}
.segmented button,
.tool-btn,
.generate,
.landing button,
.download-link {
border: 1px solid var(--line);
border-radius: 8px;
min-height: 38px;
color: var(--text);
background: var(--panel-2);
cursor: pointer;
transition: border-color 160ms ease, background 160ms ease, transform 160ms ease;
}
.segmented button.active {
background: #254235;
border-color: rgba(111, 179, 111, 0.85);
}
.tool-btn {
display: inline-flex;
align-items: center;
justify-content: center;
padding: 0 12px;
font-size: 13px;
white-space: nowrap;
}
.tool-btn.active {
border-color: var(--green);
background: #223b31;
}
.generate {
width: 100%;
min-height: 44px;
background: #315f46;
border-color: rgba(111, 179, 111, 0.78);
font-weight: 760;
}
.generate:disabled {
opacity: 0.62;
cursor: wait;
}
.download-link {
display: none;
align-items: center;
justify-content: center;
min-height: 40px;
color: var(--text);
text-decoration: none;
background: #202b31;
border-color: rgba(112, 182, 199, 0.6);
font-weight: 700;
}
.preset-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 10px;
}
.preset-card {
display: grid;
grid-template-rows: 82px auto;
overflow: hidden;
border: 1px solid var(--line-soft);
border-radius: 8px;
background: #111718;
cursor: pointer;
min-width: 0;
text-align: left;
padding: 0;
}
.preset-card.active {
border-color: var(--green);
box-shadow: inset 0 0 0 1px rgba(111, 179, 111, 0.35);
}
.preset-card img {
width: 100%;
height: 82px;
object-fit: cover;
display: block;
}
.preset-card span {
padding: 9px;
color: var(--text);
font-size: 13px;
font-weight: 700;
overflow-wrap: anywhere;
}
.sketch-tools {
display: none;
gap: 10px;
flex-direction: column;
}
.sketch-wrap {
border: 1px solid var(--line-soft);
border-radius: 8px;
background: #101617;
padding: 10px;
}
#sketchCanvas {
display: block;
width: 100%;
aspect-ratio: 1;
border-radius: 6px;
background: #4d4d4d;
touch-action: none;
}
.brush-row {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 8px;
}
.field {
display: flex;
flex-direction: column;
gap: 7px;
}
.field label {
color: var(--muted);
font-size: 12px;
font-weight: 700;
text-transform: uppercase;
}
.number-input {
width: 100%;
min-height: 40px;
color: var(--text);
border: 1px solid var(--line-soft);
border-radius: 8px;
background: #101617;
padding: 0 12px;
}
.stats-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 8px;
}
.stat {
padding: 10px;
border: 1px solid var(--line-soft);
border-radius: 8px;
background: #101617;
min-width: 0;
}
.stat span {
display: block;
color: var(--muted);
font-size: 11px;
text-transform: uppercase;
margin-bottom: 5px;
}
.stat strong {
display: block;
color: var(--text);
font-size: 17px;
overflow-wrap: anywhere;
}
.preview-tabs {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 6px;
}
.preview-tabs button {
min-height: 34px;
border: 1px solid var(--line-soft);
border-radius: 8px;
color: var(--muted);
background: #101617;
cursor: pointer;
font-size: 12px;
}
.preview-tabs button.active {
color: var(--text);
border-color: var(--cyan);
}
.preview-img {
display: none;
width: 100%;
aspect-ratio: 1;
object-fit: cover;
border: 1px solid var(--line-soft);
border-radius: 8px;
background: #101617;
}
.preview-img.active {
display: block;
}
.about {
gap: 10px;
padding-top: 2px;
}
.about p {
color: var(--muted);
font-size: 13px;
line-height: 1.45;
}
.pipeline {
display: grid;
gap: 8px;
}
.stage {
border-left: 3px solid var(--green);
padding: 8px 10px;
background: #101617;
border-radius: 0 8px 8px 0;
}
.stage strong {
display: block;
font-size: 13px;
margin-bottom: 4px;
}
.stage span {
display: block;
color: var(--muted);
font-size: 12px;
line-height: 1.4;
}
.links {
display: flex;
gap: 12px;
flex-wrap: wrap;
font-size: 13px;
}
.landing {
position: fixed;
inset: 0;
z-index: 10;
display: grid;
align-items: center;
padding: 8vh min(7vw, 92px);
background: rgba(5, 8, 9, 0.68);
backdrop-filter: blur(7px);
}
.landing.hidden {
display: none;
}
.landing-copy {
width: min(740px, 100%);
display: flex;
flex-direction: column;
gap: 18px;
}
.landing h1 {
font-size: clamp(38px, 6vw, 82px);
line-height: 0.95;
max-width: 12ch;
font-weight: 820;
}
.landing .paper-title {
color: #d7ded8;
max-width: 720px;
font-size: clamp(16px, 1.8vw, 22px);
line-height: 1.35;
}
.landing button {
width: 150px;
min-height: 46px;
background: #315f46;
border-color: rgba(111, 179, 111, 0.84);
font-weight: 800;
}
.error-box {
display: none;
border: 1px solid rgba(197, 109, 77, 0.7);
border-radius: 8px;
color: #ffd7c8;
background: rgba(86, 37, 27, 0.55);
padding: 10px;
font-size: 13px;
line-height: 1.4;
}
.loading-overlay {
position: absolute;
inset: 0;
display: none;
align-items: center;
justify-content: center;
background: rgba(7, 10, 11, 0.48);
z-index: 4;
}
.loading-overlay.active {
display: flex;
}
.loader {
width: min(420px, calc(100% - 48px));
border: 1px solid rgba(130, 151, 151, 0.32);
border-radius: 8px;
background: rgba(16, 22, 23, 0.86);
padding: 18px;
color: var(--text);
text-align: center;
}
.loader strong {
display: block;
margin-bottom: 6px;
}
.loader span {
color: var(--muted);
font-size: 13px;
}
@media (max-width: 980px) {
body {
overflow: auto;
}
.app-shell {
grid-template-columns: 1fr;
height: auto;
min-height: 100vh;
}
.viewport {
height: 62vh;
min-height: 460px;
}
.sidebar {
height: auto;
border-left: 0;
border-top: 1px solid var(--line);
}
.status-pill {
margin-left: 0;
width: 100%;
}
}
@media (max-width: 520px) {
.side-inner {
padding: 14px;
}
.preset-grid,
.stats-grid {
grid-template-columns: 1fr;
}
.scene-toolbar {
left: 10px;
right: 10px;
bottom: 10px;
}
.range {
width: 112px;
}
}
</style>
</head>
<body>
<main class="app-shell">
<section class="viewport" aria-label="Interactive terrain viewer">
<div id="sceneHost"></div>
<div class="loading-overlay" id="loadingOverlay">
<div class="loader">
<strong>Generating terrain</strong>
<span>ZeroGPU is running the diffusion pipeline.</span>
</div>
</div>
<div class="scene-toolbar">
<div class="toolbar-group">
<span class="toolbar-label">Elevation</span>
<input class="range" id="exaggeration" type="range" min="0.2" max="4" value="1.55" step="0.05">
</div>
<button class="tool-btn" id="wireToggle" type="button">Wireframe</button>
<button class="tool-btn" id="resetCamera" type="button">Reset view</button>
<div class="status-pill">
<span class="dot" id="statusDot"></span>
<span id="statusText">Loading interface</span>
</div>
</div>
</section>
<aside class="sidebar">
<div class="side-inner">
<header class="brand">
<p class="eyebrow">Terrain Diffusion Demo</p>
<h1>xandergos/terrain-diffusion-30m</h1>
<p class="subcopy">Diffusion-generated heightmaps rendered as an interactive WebGL terrain mesh.</p>
</header>
<section class="panel">
<h2>Input</h2>
<div class="segmented" role="tablist" aria-label="Input mode">
<button class="active" id="presetMode" type="button">Presets</button>
<button id="sketchMode" type="button">Sketch</button>
</div>
<div class="preset-grid" id="presetGrid"></div>
<div class="sketch-tools" id="sketchTools">
<div class="sketch-wrap">
<canvas id="sketchCanvas" aria-label="Custom elevation sketch"></canvas>
</div>
<div class="brush-row">
<button class="tool-btn active" data-brush="high" type="button">High</button>
<button class="tool-btn" data-brush="low" type="button">Low</button>
<button class="tool-btn" data-brush="mid" type="button">Mid</button>
</div>
<button class="tool-btn" id="clearSketch" type="button">Clear sketch</button>
</div>
<div class="field">
<label for="seed">Seed</label>
<input class="number-input" id="seed" type="number" min="0" max="2147483647" step="1" value="42">
</div>
<button class="generate" id="generateBtn" type="button">Generate Terrain</button>
<a class="download-link" id="downloadLink" href="#" download>Download 16-bit PNG</a>
<div class="error-box" id="errorBox"></div>
</section>
<section class="panel">
<h2>Stats</h2>
<div class="stats-grid">
<div class="stat"><span>Min</span><strong id="minStat">-</strong></div>
<div class="stat"><span>Max</span><strong id="maxStat">-</strong></div>
<div class="stat"><span>Mean</span><strong id="meanStat">-</strong></div>
<div class="stat"><span>Area</span><strong id="areaStat">-</strong></div>
</div>
</section>
<section class="panel">
<h2>Maps</h2>
<div class="preview-tabs">
<button class="active" data-preview="relief" type="button">Relief</button>
<button data-preview="elevation" type="button">Elevation</button>
<button data-preview="conditioning" type="button">Conditioning</button>
</div>
<img class="preview-img active" id="reliefPreview" alt="Shaded relief preview">
<img class="preview-img" id="elevationPreview" alt="Elevation color preview">
<img class="preview-img" id="conditioningPreview" alt="Conditioning map preview">
</section>
<section class="panel about">
<h2>About</h2>
<p id="paperTitle">Loading paper details.</p>
<div class="pipeline" id="pipelineStages"></div>
<p id="whyDiffusion"></p>
<p id="infiniteInfo"></p>
<div class="links">
<a id="paperLink" href="https://arxiv.org/abs/2512.08309" target="_blank" rel="noreferrer">Paper</a>
<a id="modelLink" href="https://huggingface.co/xandergos/terrain-diffusion-30m" target="_blank" rel="noreferrer">Model</a>
</div>
</section>
</div>
</aside>
</main>
<section class="landing" id="landing">
<div class="landing-copy">
<p class="eyebrow">Terrain Diffusion Demo</p>
<h1>xandergos/terrain-diffusion-30m</h1>
<p class="paper-title">InfiniteDiffusion: Bridging Learned Fidelity and Procedural Utility for Open-World Terrain Generation</p>
<button id="enterBtn" type="button">Enter</button>
</div>
</section>
<script>
const state = {
mode: "Presets",
selectedPreset: "Mountains",
brush: "high",
currentMesh: null,
baseHeights: null,
terrainMesh: null,
waterMesh: null,
scene: null,
camera: null,
renderer: null,
controls: null,
material: null
};
const els = {
sceneHost: document.getElementById("sceneHost"),
loadingOverlay: document.getElementById("loadingOverlay"),
statusText: document.getElementById("statusText"),
statusDot: document.getElementById("statusDot"),
presetGrid: document.getElementById("presetGrid"),
sketchTools: document.getElementById("sketchTools"),
sketchCanvas: document.getElementById("sketchCanvas"),
seed: document.getElementById("seed"),
generateBtn: document.getElementById("generateBtn"),
errorBox: document.getElementById("errorBox"),
downloadLink: document.getElementById("downloadLink"),
reliefPreview: document.getElementById("reliefPreview"),
elevationPreview: document.getElementById("elevationPreview"),
conditioningPreview: document.getElementById("conditioningPreview"),
minStat: document.getElementById("minStat"),
maxStat: document.getElementById("maxStat"),
meanStat: document.getElementById("meanStat"),
areaStat: document.getElementById("areaStat"),
exaggeration: document.getElementById("exaggeration"),
wireToggle: document.getElementById("wireToggle"),
resetCamera: document.getElementById("resetCamera")
};
function setStatus(text, kind = "busy") {
els.statusText.textContent = text;
els.statusDot.classList.toggle("ready", kind === "ready");
els.statusDot.classList.toggle("error", kind === "error");
}
function setError(message) {
els.errorBox.style.display = message ? "block" : "none";
els.errorBox.textContent = message || "";
if (message) setStatus("Error", "error");
}
async function callApi(name, data = []) {
const post = await fetch(`/gradio_api/call/${name}`, {
method: "POST",
headers: {"Content-Type": "application/json"},
body: JSON.stringify({data})
});
if (!post.ok) {
throw new Error(`API ${name} failed with HTTP ${post.status}`);
}
const event = await post.json();
if (!event.event_id) {
throw new Error(`API ${name} did not return an event id`);
}
return new Promise((resolve, reject) => {
const source = new EventSource(`/gradio_api/call/${name}/${event.event_id}`);
let settled = false;
source.addEventListener("complete", (evt) => {
settled = true;
source.close();
const payload = JSON.parse(evt.data);
resolve(payload[0]);
});
source.addEventListener("error", (evt) => {
if (settled) return;
settled = true;
source.close();
let detail = `API ${name} stream failed`;
if (evt.data) {
try {
detail = JSON.stringify(JSON.parse(evt.data));
} catch {
detail = evt.data;
}
}
reject(new Error(detail));
});
});
}
function initThree() {
if (!window.THREE) {
setError("Three.js did not load.");
return;
}
state.scene = new THREE.Scene();
state.scene.background = new THREE.Color(0x0a0e0f);
state.camera = new THREE.PerspectiveCamera(48, 1, 0.01, 200);
state.camera.position.set(25, 18, 28);
state.renderer = new THREE.WebGLRenderer({antialias: true, powerPreference: "high-performance"});
state.renderer.setPixelRatio(Math.min(window.devicePixelRatio || 1, 2));
state.renderer.shadowMap.enabled = true;
els.sceneHost.appendChild(state.renderer.domElement);
const hemi = new THREE.HemisphereLight(0xd8f2ff, 0x1b241b, 0.72);
state.scene.add(hemi);
const sun = new THREE.DirectionalLight(0xfff0d0, 1.28);
sun.position.set(28, 34, 18);
state.scene.add(sun);
const grid = new THREE.GridHelper(36, 18, 0x334044, 0x1f2a2d);
grid.position.y = -0.04;
state.scene.add(grid);
if (THREE.OrbitControls) {
state.controls = new THREE.OrbitControls(state.camera, state.renderer.domElement);
state.controls.enableDamping = true;
state.controls.dampingFactor = 0.08;
state.controls.target.set(0, 0, 0);
}
window.addEventListener("resize", resizeThree);
resizeThree();
renderPlaceholderTerrain();
animate();
}
function resizeThree() {
if (!state.renderer || !state.camera) return;
const rect = els.sceneHost.getBoundingClientRect();
state.camera.aspect = Math.max(1, rect.width) / Math.max(1, rect.height);
state.camera.updateProjectionMatrix();
state.renderer.setSize(rect.width, rect.height, false);
}
function animate() {
requestAnimationFrame(animate);
if (state.controls) state.controls.update();
if (state.renderer && state.scene && state.camera) {
state.renderer.render(state.scene, state.camera);
}
}
function terrainColor(height, minH, maxH) {
const snowLine = Math.max(1850, minH + (maxH - minH) * 0.72);
if (height < 0) return new THREE.Color(0x2f7097).lerp(new THREE.Color(0x12355d), Math.min(1, Math.abs(height) / Math.max(1, Math.abs(minH))));
if (height < 70) return new THREE.Color(0xdbbf83);
if (height < 950) return new THREE.Color(0x2f7447).lerp(new THREE.Color(0x7e904f), (height - 70) / 880);
if (height < snowLine) return new THREE.Color(0x6b6c63).lerp(new THREE.Color(0xa09e93), (height - 950) / Math.max(1, snowLine - 950));
return new THREE.Color(0xdae0db).lerp(new THREE.Color(0xfafcf7), (height - snowLine) / Math.max(1, maxH - snowLine));
}
function clearTerrain() {
if (state.terrainMesh) {
state.scene.remove(state.terrainMesh);
state.terrainMesh.geometry.dispose();
state.terrainMesh.material.dispose();
state.terrainMesh = null;
}
if (state.waterMesh) {
state.scene.remove(state.waterMesh);
state.waterMesh.geometry.dispose();
state.waterMesh.material.dispose();
state.waterMesh = null;
}
}
function renderPlaceholderTerrain() {
const rows = 96;
const cols = 96;
const values = [];
for (let r = 0; r < rows; r++) {
const row = [];
for (let c = 0; c < cols; c++) {
const x = (c / (cols - 1) - 0.5) * 2;
const y = (r / (rows - 1) - 0.5) * 2;
const ridge = Math.exp(-(x * x * 2.5 + y * y * 8)) * 1550;
const folds = Math.sin((x + y) * 10) * 90 + Math.cos(x * 7) * 120;
row.push(ridge + folds - 180);
}
values.push(row);
}
renderTerrain({rows, cols, values, size_x_m: 30720, size_y_m: 30720});
setStatus("Ready", "ready");
}
function renderTerrain(mesh) {
state.currentMesh = mesh;
state.baseHeights = mesh.values.flat();
clearTerrain();
const rows = mesh.rows;
const cols = mesh.cols;
const sizeX = mesh.size_x_m / 1000;
const sizeY = mesh.size_y_m / 1000;
const heights = state.baseHeights;
const minH = Math.min(...heights);
const maxH = Math.max(...heights);
const exaggeration = Number(els.exaggeration.value);
const positions = [];
const colors = [];
const indices = [];
for (let r = 0; r < rows; r++) {
for (let c = 0; c < cols; c++) {
const idx = r * cols + c;
const x = (c / (cols - 1) - 0.5) * sizeX;
const z = (r / (rows - 1) - 0.5) * sizeY;
const y = (heights[idx] / 1000) * exaggeration;
positions.push(x, y, z);
const color = terrainColor(heights[idx], minH, maxH);
colors.push(color.r, color.g, color.b);
}
}
for (let r = 0; r < rows - 1; r++) {
for (let c = 0; c < cols - 1; c++) {
const a = r * cols + c;
const b = a + 1;
const d = (r + 1) * cols + c;
const e = d + 1;
indices.push(a, d, b, b, d, e);
}
}
const geometry = new THREE.BufferGeometry();
geometry.setIndex(indices);
geometry.setAttribute("position", new THREE.Float32BufferAttribute(positions, 3));
geometry.setAttribute("color", new THREE.Float32BufferAttribute(colors, 3));
geometry.computeVertexNormals();
const material = new THREE.MeshStandardMaterial({
vertexColors: true,
roughness: 0.88,
metalness: 0,
side: THREE.DoubleSide,
wireframe: els.wireToggle.classList.contains("active")
});
state.terrainMesh = new THREE.Mesh(geometry, material);
state.scene.add(state.terrainMesh);
const waterGeometry = new THREE.PlaneGeometry(sizeX, sizeY);
const waterMaterial = new THREE.MeshStandardMaterial({
color: 0x295c78,
transparent: true,
opacity: 0.34,
roughness: 0.55,
metalness: 0
});
state.waterMesh = new THREE.Mesh(waterGeometry, waterMaterial);
state.waterMesh.rotation.x = -Math.PI / 2;
state.waterMesh.position.y = 0;
state.scene.add(state.waterMesh);
}
function updateExaggeration() {
if (!state.terrainMesh || !state.baseHeights || !state.currentMesh) return;
const positions = state.terrainMesh.geometry.attributes.position;
const exaggeration = Number(els.exaggeration.value);
for (let i = 0; i < state.baseHeights.length; i++) {
positions.setY(i, (state.baseHeights[i] / 1000) * exaggeration);
}
positions.needsUpdate = true;
state.terrainMesh.geometry.computeVertexNormals();
}
function resetCamera() {
  // Move the camera back to its default viewpoint and, when controls exist,
  // re-target them at the scene origin.
  const { camera, controls } = state;
  camera.position.set(25, 18, 28);
  if (!controls) return;
  controls.target.set(0, 0, 0);
  controls.update();
}
function initSketch() {
  // Set up the conditioning sketch pad: a square canvas filled with mid-grey on
  // which the selected brush stamps white ("high"), black ("low") or grey
  // ("mid") discs. Also wires the brush-selector buttons and the clear button.
  const canvas = els.sketchCanvas;
  const ctx = canvas.getContext("2d");
  const background = "#4d4d4d";
  // Logical edge length of the drawing surface (CSS px) and the device-pixel
  // factor applied through ctx.setTransform; both 0 until the first resize().
  let logicalSize = 0;
  let pixelScale = 0;
  const resize = () => {
    const rect = canvas.getBoundingClientRect();
    const size = Math.max(220, Math.floor(rect.width));
    const scale = Math.min(window.devicePixelRatio || 1, 2);
    // Assigning canvas.width/height clears the bitmap, so skip resizes that
    // leave the surface unchanged (e.g. height-only window changes) instead of
    // wiping the user's sketch.
    if (size === logicalSize && scale === pixelScale) return;
    let snapshot = null;
    if (logicalSize > 0) {
      // Copy the existing strokes so they survive the resize (rescaled).
      snapshot = document.createElement("canvas");
      snapshot.width = canvas.width;
      snapshot.height = canvas.height;
      snapshot.getContext("2d").drawImage(canvas, 0, 0);
    }
    canvas.width = Math.floor(size * scale);
    canvas.height = Math.floor(size * scale);
    ctx.setTransform(scale, 0, 0, scale, 0, 0);
    ctx.fillStyle = background;
    ctx.fillRect(0, 0, size, size);
    if (snapshot) ctx.drawImage(snapshot, 0, 0, size, size);
    logicalSize = size;
    pixelScale = scale;
  };
  resize();
  window.addEventListener("resize", resize);
  let drawing = false;
  const draw = (evt) => {
    if (!drawing) return;
    const rect = canvas.getBoundingClientRect();
    // Map CSS pixels to logical drawing units. They differ whenever the element
    // is displayed at a size other than logicalSize, e.g. narrower than the
    // 220px minimum or laid out after the last resize().
    const sx = rect.width > 0 ? logicalSize / rect.width : 1;
    const sy = rect.height > 0 ? logicalSize / rect.height : 1;
    const x = (evt.clientX - rect.left) * sx;
    const y = (evt.clientY - rect.top) * sy;
    ctx.fillStyle = state.brush === "high" ? "#ffffff" : state.brush === "low" ? "#000000" : background;
    ctx.beginPath();
    ctx.arc(x, y, state.brush === "mid" ? 18 : 15, 0, Math.PI * 2);
    ctx.fill();
  };
  canvas.addEventListener("pointerdown", (evt) => {
    drawing = true;
    // Keep receiving move/up events even if the pointer leaves the canvas.
    canvas.setPointerCapture(evt.pointerId);
    draw(evt);
  });
  canvas.addEventListener("pointermove", draw);
  canvas.addEventListener("pointerup", () => { drawing = false; });
  canvas.addEventListener("pointercancel", () => { drawing = false; });
  document.querySelectorAll("[data-brush]").forEach((btn) => {
    btn.addEventListener("click", () => {
      state.brush = btn.dataset.brush;
      document.querySelectorAll("[data-brush]").forEach((item) => item.classList.remove("active"));
      btn.classList.add("active");
    });
  });
  document.getElementById("clearSketch").addEventListener("click", () => {
    // Repaint the whole logical surface; the CSS box can be smaller than the
    // 220px minimum, which would leave strokes outside it untouched.
    ctx.fillStyle = background;
    ctx.fillRect(0, 0, logicalSize, logicalSize);
  });
}
function setMode(mode) {
  // Switch the input panel between the preset grid and the sketch tools and
  // highlight the matching mode button.
  state.mode = mode;
  const presetsActive = mode === "Presets";
  const sketchActive = mode === "Custom Sketch";
  document.getElementById("presetMode").classList.toggle("active", presetsActive);
  document.getElementById("sketchMode").classList.toggle("active", sketchActive);
  els.presetGrid.style.display = presetsActive ? "grid" : "none";
  els.sketchTools.style.display = sketchActive ? "flex" : "none";
}
function renderPresets(presets) {
  // Rebuild the preset picker: one button per preset showing its conditioning
  // preview. Clicking a card selects that preset and shows its preview.
  els.presetGrid.innerHTML = "";
  presets.forEach((preset) => {
    const card = document.createElement("button");
    card.className = `preset-card${preset.name === state.selectedPreset ? " active" : ""}`;
    card.type = "button";
    card.title = preset.description;
    // Build the children through DOM properties instead of an HTML template so
    // a name containing quotes or "<" cannot break the alt attribute or markup.
    const thumb = document.createElement("img");
    thumb.alt = `${preset.name} conditioning preview`;
    thumb.src = preset.preview;
    const label = document.createElement("span");
    label.textContent = preset.name;
    card.append(thumb, label);
    card.addEventListener("click", () => {
      state.selectedPreset = preset.name;
      document.querySelectorAll(".preset-card").forEach((item) => item.classList.remove("active"));
      card.classList.add("active");
      els.conditioningPreview.src = preset.preview;
    });
    els.presetGrid.appendChild(card);
  });
  // Show the preview of whichever preset is currently selected, if it exists.
  const current = presets.find((preset) => preset.name === state.selectedPreset);
  if (current) els.conditioningPreview.src = current.preview;
}
function populateModelInfo(info) {
  // Fill the model/paper information panel from the get_model_info payload.
  const setText = (id, text) => {
    document.getElementById(id).textContent = text;
  };
  setText("paperTitle", `${info.paper_title}. This demo uses the ${info.resolution}.`);
  setText("whyDiffusion", info.why_diffusion);
  setText("infiniteInfo", info.infinite);
  document.getElementById("paperLink").href = info.paper_url;
  document.getElementById("modelLink").href = info.model_url;
  const stages = document.getElementById("pipelineStages");
  stages.innerHTML = "";
  for (const stage of info.pipeline) {
    const node = document.createElement("div");
    node.className = "stage";
    // NOTE(review): stage.name / stage.text are inserted as HTML; keep the
    // server payload trusted or switch to textContent if markup isn't needed.
    node.innerHTML = `<strong>${stage.name}</strong><span>${stage.text}</span>`;
    stages.appendChild(node);
  }
}
function updateStats(stats) {
  // Write the elevation summary (metres) and covered area (km2) into the
  // stat tiles.
  const rows = [
    [els.minStat, stats.min_m, "m"],
    [els.maxStat, stats.max_m, "m"],
    [els.meanStat, stats.mean_m, "m"],
    [els.areaStat, stats.area_km2, "km2"]
  ];
  for (const [el, value, unit] of rows) {
    el.textContent = `${value} ${unit}`;
  }
}
function updateResult(result) {
  // Push a successful generate_terrain response into the 3D view, the stat
  // tiles, the image previews and the heightmap download link.
  const { mesh, stats } = result;
  renderTerrain(mesh);
  updateStats(stats);
  els.reliefPreview.src = result.relief_image;
  els.elevationPreview.src = result.elevation_image;
  els.conditioningPreview.src = result.conditioning_image;
  const link = els.downloadLink;
  link.href = result.heightmap_url;
  link.style.display = "flex";
  setStatus(`Generated in ${stats.elapsed_s}s`, "ready");
}
async function generateTerrain() {
  // Ask the backend for a terrain using the selected preset (or the current
  // sketch in "Custom Sketch" mode) and render the returned payload. The UI is
  // locked (button disabled, overlay shown) for the duration of the request.
  setError("");
  els.generateBtn.disabled = true;
  els.loadingOverlay.classList.add("active");
  setStatus("Generating", "busy");
  try {
    // parseInt yields NaN for non-numeric input, and NaN passes straight through
    // Math.min/Math.max; fall back to the default seed instead of sending NaN.
    const parsed = Number.parseInt(els.seed.value || "42", 10);
    const seed = Number.isFinite(parsed) ? Math.max(0, Math.min(2147483647, parsed)) : 42;
    els.seed.value = String(seed);
    const sketch = state.mode === "Custom Sketch" ? els.sketchCanvas.toDataURL("image/png") : "";
    const result = await callApi("generate_terrain", [state.selectedPreset, seed, state.mode, sketch]);
    if (!result || !result.ok) {
      throw new Error("Generation did not return a terrain payload.");
    }
    updateResult(result);
  } catch (err) {
    // Leave the busy state so the status indicator does not keep reporting
    // "Generating" after a failure; setError runs last so it can override it.
    setStatus("Ready", "ready");
    setError(err.message || String(err));
  } finally {
    els.generateBtn.disabled = false;
    els.loadingOverlay.classList.remove("active");
  }
}
function bindUi() {
  // Attach the static UI event handlers (landing screen, mode switch, generate,
  // exaggeration slider, wireframe toggle, camera reset, preview tabs).
  const byId = (id) => document.getElementById(id);
  byId("enterBtn").addEventListener("click", () => {
    byId("landing").classList.add("hidden");
    resizeThree();
  });
  byId("presetMode").addEventListener("click", () => setMode("Presets"));
  byId("sketchMode").addEventListener("click", () => setMode("Custom Sketch"));
  els.generateBtn.addEventListener("click", generateTerrain);
  els.exaggeration.addEventListener("input", updateExaggeration);
  els.wireToggle.addEventListener("click", () => {
    // classList.toggle returns whether the class is present afterwards.
    const wireframeOn = els.wireToggle.classList.toggle("active");
    if (state.terrainMesh) state.terrainMesh.material.wireframe = wireframeOn;
  });
  els.resetCamera.addEventListener("click", resetCamera);
  const selectPreview = (button) => {
    document.querySelectorAll("[data-preview]").forEach((item) => item.classList.remove("active"));
    document.querySelectorAll(".preview-img").forEach((item) => item.classList.remove("active"));
    button.classList.add("active");
    byId(`${button.dataset.preview}Preview`).classList.add("active");
  };
  for (const button of document.querySelectorAll("[data-preview]")) {
    button.addEventListener("click", () => selectPreview(button));
  }
}
async function boot() {
  // Page entry point: wire handlers, create the 3D view and sketch pad, then
  // fetch presets and model metadata concurrently.
  bindUi();
  initThree();
  initSketch();
  try {
    const requests = [callApi("get_presets"), callApi("get_model_info")];
    const [presetResponse, modelInfo] = await Promise.all(requests);
    renderPresets(presetResponse.presets);
    populateModelInfo(modelInfo);
    setStatus("Ready", "ready");
  } catch (failure) {
    setError(failure.message || String(failure));
  }
}
boot();
</script>
</body>
</html>
"""
# Server hosting both the custom HTML/file routes and the named API endpoints
# below. ``demo`` is an alias for the same object (presumably the name the
# hosting runtime looks up - confirm before removing).
demo = app = gr.Server(title="Terrain Diffusion Demo")
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page front end stored in ``INDEX_HTML``."""
    page = INDEX_HTML
    return HTMLResponse(content=page)
@app.get("/healthz")
def healthz():
    """Report that the server is up and which model repository it uses.

    Does not load or touch the pipeline, so it stays cheap.
    """
    body = {"ok": True, "model": MODEL_ID}
    return JSONResponse(content=body)
@app.get("/heightmaps/{filename}")
def heightmap(filename: str):
    """Serve a previously generated heightmap PNG from ``OUTPUT_DIR``.

    Only the final path component of ``filename`` is used. Names that do not
    refer to a regular file inside ``OUTPUT_DIR`` get a JSON 404.
    """
    # Path(...).name strips any directory parts, blocking "a/../b"-style traversal.
    safe_name = Path(filename).name
    path = OUTPUT_DIR / safe_name
    # is_file() rather than exists(): "." reduces to "" and ".." stays "..", and
    # both resolve to existing directories that FileResponse cannot serve.
    if not safe_name or not path.is_file():
        return JSONResponse({"error": "heightmap not found"}, status_code=404)
    return FileResponse(path, media_type="image/png", filename=safe_name)
@app.api(name="get_presets", queue=False)
def get_presets() -> dict:
    """Return the selectable presets wrapped under the ``"presets"`` key."""
    payload = _preset_payload()
    return {"presets": payload}
@app.api(name="get_model_info", queue=False)
def get_model_info() -> dict:
    """Return the paper/model description rendered in the info panel."""
    info = _model_info_payload()
    return info
@app.api(name="generate_terrain", concurrency_limit=1, time_limit=180)
def generate_terrain(preset_choice: str, seed: int, input_mode: str, sketch_data_url: str) -> dict:
    """Generate one terrain and return the payload consumed by the web UI.

    concurrency_limit=1 runs at most one generation at a time.

    Args:
        preset_choice: Name of the selected preset.
        seed: Seed passed through to ``_run_generation``.
        input_mode: Input mode sent by the front end ("Presets" or "Custom Sketch").
        sketch_data_url: PNG data URL of the sketch canvas, or "" in preset mode.
    """
    result = _run_generation(preset_choice, seed, input_mode, sketch_data_url)
    return result
if __name__ == "__main__":
    # Load (and cache) the WorldPipeline before serving so the first request
    # does not pay the model start-up cost.
    _get_world()
    listen_host = "0.0.0.0"  # bind on all interfaces
    app.launch(server_name=listen_host)