Efradeca commited on
Commit
426d03d
·
verified ·
1 Parent(s): ac829d2

Upload web/app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. web/app.py +61 -0
web/app.py CHANGED
@@ -64,6 +64,20 @@ NU = cfg["generator"]["num_uv"]
64
  _predict(jnp.zeros(NU * NU * 3))
65
  print(f"Model loaded and JIT-compiled. Grid: {NU}x{NU}")
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  # Static topology (edges, boundary vertices, faces) - sent once
68
  EDGES = np.array(list(mesh.edges())).tolist()
69
  BOUNDARY = sorted(set(mesh.vertices_on_boundary()))
@@ -156,6 +170,53 @@ def predict(req: PredictRequest):
156
  })
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # Serve static files
160
  STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
161
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 
64
  _predict(jnp.zeros(NU * NU * 3))
65
  print(f"Model loaded and JIT-compiled. Grid: {NU}x{NU}")
66
 
67
+ # Load VAE model for diversity sampling
68
+ VAE_PATH = os.path.join(ROOT, "data", f"variational_formfinder_variational_{TASK}.eqx")
69
+ vae_model = None
70
+ if os.path.exists(VAE_PATH):
71
+ try:
72
+ vae_cfg_path = os.path.join(ROOT, "scripts", f"variational_{TASK}.yml")
73
+ with open(vae_cfg_path) as f:
74
+ vae_cfg = yaml.load(f, Loader=yaml.FullLoader)
75
+ vae_skeleton = build_neural_model("variational_formfinder", vae_cfg, gen, mk)
76
+ vae_model = load_model(VAE_PATH, vae_skeleton)
77
+ print("VAE model loaded for diversity sampling.")
78
+ except Exception as e:
79
+ print(f"VAE not loaded: {e}")
80
+
81
  # Static topology (edges, boundary vertices, faces) - sent once
82
  EDGES = np.array(list(mesh.edges())).tolist()
83
  BOUNDARY = sorted(set(mesh.vertices_on_boundary()))
 
170
  })
171
 
172
 
173
+ class DiversityRequest(BaseModel):
174
+ c1_z: float = 3.0
175
+ c2_x: float = 0.0
176
+ c2_z: float = 1.5
177
+ c3_y: float = 0.0
178
+ n_samples: int = 5
179
+
180
+
181
+ @app.post("/api/diversity")
182
+ def sample_diversity(req: DiversityRequest):
183
+ """Sample diverse equilibrium solutions from the VAE."""
184
+ if vae_model is None:
185
+ return JSONResponse({"error": "VAE model not available"}, status_code=404)
186
+
187
+ transform = jnp.array([
188
+ [0.0, 0.0, req.c1_z],
189
+ [req.c2_x, 0.0, req.c2_z],
190
+ [0.0, req.c3_y, 0.0],
191
+ [0.0, 0.0, 0.0],
192
+ ])
193
+
194
+ xyz_target = gen.evaluate_points(transform)
195
+ key = jrn.PRNGKey(int(time.time()) % 10000)
196
+ x_hats, qs = vae_model.sample(xyz_target, structure, key, num_samples=req.n_samples)
197
+
198
+ samples = []
199
+ for i in range(req.n_samples):
200
+ pred_np = np.array(x_hats[i]).reshape(-1, 3)
201
+ q_np = np.array(qs[i]).flatten()
202
+ samples.append({
203
+ "predicted": pred_np.tolist(),
204
+ "q": q_np.tolist(),
205
+ })
206
+
207
+ return JSONResponse({
208
+ "samples": samples,
209
+ "n_samples": req.n_samples,
210
+ "has_vae": True,
211
+ })
212
+
213
+
214
+ @app.get("/api/has_vae")
215
+ def has_vae():
216
+ """Check if VAE model is available."""
217
+ return {"has_vae": vae_model is not None}
218
+
219
+
220
  # Serve static files
221
  STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
222
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")