Upload web/app.py with huggingface_hub
Browse files- web/app.py +61 -0
web/app.py
CHANGED
|
@@ -64,6 +64,20 @@ NU = cfg["generator"]["num_uv"]
|
|
| 64 |
_predict(jnp.zeros(NU * NU * 3))
|
| 65 |
print(f"Model loaded and JIT-compiled. Grid: {NU}x{NU}")
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
# Static topology (edges, boundary vertices, faces) - sent once
|
| 68 |
EDGES = np.array(list(mesh.edges())).tolist()
|
| 69 |
BOUNDARY = sorted(set(mesh.vertices_on_boundary()))
|
|
@@ -156,6 +170,53 @@ def predict(req: PredictRequest):
|
|
| 156 |
})
|
| 157 |
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
# Serve static files
|
| 160 |
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
|
| 161 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
|
|
| 64 |
_predict(jnp.zeros(NU * NU * 3))
|
| 65 |
print(f"Model loaded and JIT-compiled. Grid: {NU}x{NU}")
|
| 66 |
|
| 67 |
+
# Load VAE model for diversity sampling
|
| 68 |
+
VAE_PATH = os.path.join(ROOT, "data", f"variational_formfinder_variational_{TASK}.eqx")
|
| 69 |
+
vae_model = None
|
| 70 |
+
if os.path.exists(VAE_PATH):
|
| 71 |
+
try:
|
| 72 |
+
vae_cfg_path = os.path.join(ROOT, "scripts", f"variational_{TASK}.yml")
|
| 73 |
+
with open(vae_cfg_path) as f:
|
| 74 |
+
vae_cfg = yaml.load(f, Loader=yaml.FullLoader)
|
| 75 |
+
vae_skeleton = build_neural_model("variational_formfinder", vae_cfg, gen, mk)
|
| 76 |
+
vae_model = load_model(VAE_PATH, vae_skeleton)
|
| 77 |
+
print("VAE model loaded for diversity sampling.")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"VAE not loaded: {e}")
|
| 80 |
+
|
| 81 |
# Static topology (edges, boundary vertices, faces) - sent once
|
| 82 |
EDGES = np.array(list(mesh.edges())).tolist()
|
| 83 |
BOUNDARY = sorted(set(mesh.vertices_on_boundary()))
|
|
|
|
| 170 |
})
|
| 171 |
|
| 172 |
|
| 173 |
+
class DiversityRequest(BaseModel):
|
| 174 |
+
c1_z: float = 3.0
|
| 175 |
+
c2_x: float = 0.0
|
| 176 |
+
c2_z: float = 1.5
|
| 177 |
+
c3_y: float = 0.0
|
| 178 |
+
n_samples: int = 5
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@app.post("/api/diversity")
|
| 182 |
+
def sample_diversity(req: DiversityRequest):
|
| 183 |
+
"""Sample diverse equilibrium solutions from the VAE."""
|
| 184 |
+
if vae_model is None:
|
| 185 |
+
return JSONResponse({"error": "VAE model not available"}, status_code=404)
|
| 186 |
+
|
| 187 |
+
transform = jnp.array([
|
| 188 |
+
[0.0, 0.0, req.c1_z],
|
| 189 |
+
[req.c2_x, 0.0, req.c2_z],
|
| 190 |
+
[0.0, req.c3_y, 0.0],
|
| 191 |
+
[0.0, 0.0, 0.0],
|
| 192 |
+
])
|
| 193 |
+
|
| 194 |
+
xyz_target = gen.evaluate_points(transform)
|
| 195 |
+
key = jrn.PRNGKey(int(time.time()) % 10000)
|
| 196 |
+
x_hats, qs = vae_model.sample(xyz_target, structure, key, num_samples=req.n_samples)
|
| 197 |
+
|
| 198 |
+
samples = []
|
| 199 |
+
for i in range(req.n_samples):
|
| 200 |
+
pred_np = np.array(x_hats[i]).reshape(-1, 3)
|
| 201 |
+
q_np = np.array(qs[i]).flatten()
|
| 202 |
+
samples.append({
|
| 203 |
+
"predicted": pred_np.tolist(),
|
| 204 |
+
"q": q_np.tolist(),
|
| 205 |
+
})
|
| 206 |
+
|
| 207 |
+
return JSONResponse({
|
| 208 |
+
"samples": samples,
|
| 209 |
+
"n_samples": req.n_samples,
|
| 210 |
+
"has_vae": True,
|
| 211 |
+
})
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@app.get("/api/has_vae")
|
| 215 |
+
def has_vae():
|
| 216 |
+
"""Check if VAE model is available."""
|
| 217 |
+
return {"has_vae": vae_model is not None}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
# Serve static files
|
| 221 |
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
|
| 222 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|