bhsinghgrid commited on
Commit
c0eacc0
·
verified ·
1 Parent(s): 5b8decd

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py CHANGED
@@ -4,9 +4,11 @@ import os
4
  import subprocess
5
  import sys
6
  from datetime import datetime
 
7
 
8
  import gradio as gr
9
  import torch
 
10
 
11
  from config import CONFIG
12
  from inference import _resolve_device, load_model, run_inference, _decode_clean, _decode_with_cleanup
@@ -17,6 +19,24 @@ RESULTS_DIR = "generated_results"
17
  DEFAULT_ANALYSIS_OUT = "analysis/outputs"
18
  os.makedirs(RESULTS_DIR, exist_ok=True)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def discover_checkpoints():
22
  found = []
@@ -35,6 +55,17 @@ def discover_checkpoints():
35
  "root": root,
36
  }
37
  )
 
 
 
 
 
 
 
 
 
 
 
38
  return found
39
 
40
 
@@ -253,6 +284,20 @@ def generate_from_ui(
253
 
254
  def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
255
  os.makedirs(output_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  cmd = [
257
  sys.executable,
258
  "analysis/run_analysis.py",
 
4
  import subprocess
5
  import sys
6
  from datetime import datetime
7
+ from pathlib import Path
8
 
9
  import gradio as gr
10
  import torch
11
+ from huggingface_hub import hf_hub_download
12
 
13
  from config import CONFIG
14
  from inference import _resolve_device, load_model, run_inference, _decode_clean, _decode_with_cleanup
 
19
  DEFAULT_ANALYSIS_OUT = "analysis/outputs"
20
  os.makedirs(RESULTS_DIR, exist_ok=True)
21
 
22
+ HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
23
+ HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
24
+
25
+
26
+ def _download_hf_default_checkpoint():
27
+ try:
28
+ cache_dir = Path(".hf_model_cache")
29
+ cache_dir.mkdir(parents=True, exist_ok=True)
30
+ ckpt = hf_hub_download(
31
+ repo_id=HF_DEFAULT_MODEL_REPO,
32
+ filename=HF_DEFAULT_MODEL_FILE,
33
+ local_dir=str(cache_dir),
34
+ local_dir_use_symlinks=False,
35
+ )
36
+ return ckpt
37
+ except Exception:
38
+ return None
39
+
40
 
41
  def discover_checkpoints():
42
  found = []
 
55
  "root": root,
56
  }
57
  )
58
+ # Space-safe fallback: always expose one downloadable checkpoint option.
59
+ hf_ckpt = _download_hf_default_checkpoint()
60
+ if hf_ckpt and os.path.exists(hf_ckpt):
61
+ found.append(
62
+ {
63
+ "label": f"HF default [{HF_DEFAULT_MODEL_REPO}]",
64
+ "path": hf_ckpt,
65
+ "experiment": "hf_default",
66
+ "root": "hf",
67
+ }
68
+ )
69
  return found
70
 
71
 
 
284
 
285
  def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
286
  os.makedirs(output_dir, exist_ok=True)
287
+ # Space-safe Task4 fallback: if ablation models don't exist, bootstrap them
288
+ # from currently selected checkpoint so Task4 can still execute end-to-end.
289
+ if str(task) == "4" and phase == "analyze":
290
+ for t in (4, 8, 16, 32, 64):
291
+ t_dir = Path("ablation_results") / f"T{t}"
292
+ t_dir.mkdir(parents=True, exist_ok=True)
293
+ dst = t_dir / "best_model.pt"
294
+ if not dst.exists():
295
+ try:
296
+ os.symlink(os.path.abspath(ckpt_path), str(dst))
297
+ except Exception:
298
+ import shutil
299
+ shutil.copy2(ckpt_path, str(dst))
300
+
301
  cmd = [
302
  sys.executable,
303
  "analysis/run_analysis.py",