bhsinghgrid committed on
Commit
a7c0255
·
verified ·
1 Parent(s): 9d76bba

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +6 -91
app.py CHANGED
@@ -135,39 +135,9 @@ HF_MODEL_REPOS = [
135
  ]
136
  HF_DEFAULT_MODEL_TYPE = os.environ.get("HF_DEFAULT_MODEL_TYPE", "d3pm_cross_attention")
137
  HF_DEFAULT_INCLUDE_NEG = os.environ.get("HF_DEFAULT_INCLUDE_NEG", "false")
138
- HF_DEFAULT_TYPE = os.environ.get("HF_DEFAULT_TYPE", "d3pm_encoder_decoder")
139
- HF_DEFAULT_NEG = os.environ.get("HF_DEFAULT_NEG", "false")
140
  HF_DEFAULT_NUM_STEPS = os.environ.get("HF_DEFAULT_NUM_STEPS")
141
  HF_DEFAULT_MODEL_SETTINGS_FILE = os.environ.get("HF_DEFAULT_MODEL_SETTINGS_FILE", "model_settings.json")
142
 
143
- # import os
144
- # # from huggingface_hub import hf_hub_download
145
- #
146
- # HF_CHECKPOINT_REPO = os.environ.get("HF_CHECKPOINT_REPO", "bhsinghgrid/devflow2")
147
- # HF_CHECKPOINT_FILE = os.environ.get("HF_CHECKPOINT_FILE", "best_model.pt")
148
-
149
- checkpoint_path = hf_hub_download(
150
- repo_id=HF_CHECKPOINT_REPO,
151
- filename=HF_CHECKPOINT_FILE,
152
- repo_type="model",
153
- )
154
-
155
-
156
-
157
def _download_hf_default_checkpoint():
    """Download the default model checkpoint from the Hugging Face Hub.

    Fetches ``HF_DEFAULT_MODEL_FILE`` from ``HF_DEFAULT_MODEL_REPO`` into a
    local ``.hf_model_cache`` directory.

    Returns:
        str | None: Local filesystem path to the downloaded checkpoint, or
        ``None`` if the download fails for any reason (offline, missing
        repo/file, auth error). Callers treat ``None`` as "no default
        checkpoint available".
    """
    try:
        cache_dir = Path(".hf_model_cache")
        cache_dir.mkdir(parents=True, exist_ok=True)
        # NOTE: `local_dir_use_symlinks` is deprecated (and ignored) in
        # current huggingface_hub releases, so it is omitted here — matching
        # the other download helpers in this module.
        ckpt = hf_hub_download(
            repo_id=HF_DEFAULT_MODEL_REPO,
            filename=HF_DEFAULT_MODEL_FILE,
            local_dir=str(cache_dir),
        )
        return ckpt
    except Exception:
        # Best-effort: swallow download errors and let the caller fall back.
        return None
170
-
171
 
172
  def _download_hf_model_settings():
173
  try:
@@ -177,7 +147,6 @@ def _download_hf_model_settings():
177
  repo_id=HF_DEFAULT_MODEL_REPO,
178
  filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
179
  local_dir=str(cache_dir),
180
- local_dir_use_symlinks=False,
181
  )
182
  with open(settings_path, "r", encoding="utf-8") as f:
183
  data = json.load(f)
@@ -203,7 +172,6 @@ def _download_hf_checkpoint(repo_id: str, filename: str = "best_model.pt"):
203
  repo_id=repo_id,
204
  filename=filename,
205
  local_dir=str(cache_dir),
206
- local_dir_use_symlinks=False,
207
  )
208
  except Exception:
209
  return None
@@ -216,7 +184,6 @@ def _download_hf_settings_for_repo(repo_id: str):
216
  repo_id=repo_id,
217
  filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
218
  local_dir=str(cache_dir),
219
- local_dir_use_symlinks=False,
220
  )
221
  with open(settings_path, "r", encoding="utf-8") as f:
222
  data = json.load(f)
@@ -308,7 +275,11 @@ def default_checkpoint_label():
308
  if not cps:
309
  return None
310
  for item in cps:
311
- if item["path"].endswith("ablation_results/T4/best_model.pt"):
 
 
 
 
312
  return item["label"]
313
  return cps[0]["label"]
314
 
@@ -458,26 +429,6 @@ def apply_preset(preset_name):
458
  return presets.get(preset_name, presets["Balanced"])
459
 
460
 
461
def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
    """Normalize whitespace and cap runs of identical tokens.

    Splits *text* on whitespace, keeps at most *max_consecutive* copies of
    any token that repeats back-to-back, then re-joins with single spaces.
    Finally, removes the space before the Devanagari danda characters
    ("।" and "॥") so punctuation attaches to the preceding word.

    Returns the cleaned string (empty input yields an empty string).
    """
    kept = []
    streak = 0
    last = None
    for word in text.split():
        # Track how long the current run of identical tokens is.
        streak = streak + 1 if word == last else 1
        last = word
        if streak <= max_consecutive:
            kept.append(word)
    result = " ".join(kept).replace(" ।", "।").replace(" ॥", "॥")
    # Re-split/join to guarantee single spacing after the replacements.
    return " ".join(result.split())
479
-
480
-
481
  def save_generation(experiment, record):
482
  ts = datetime.now().strftime("%Y%m%d")
483
  path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
@@ -1027,36 +978,6 @@ def run_single_task(model_bundle, task, output_dir, input_text, task4_phase, tas
1027
  return status, log, task_states, flow
1028
 
1029
 
1030
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase, task5_cfg, quick_mode):
    """Run analysis tasks 1-5 in sequence and return (status, combined log).

    Each task runs via the quick path or the full analysis command depending
    on *quick_mode*. If any task fails (non-zero return code) or any task
    fell back to bundled outputs, the bundled task outputs are extracted and
    the status string reflects it.

    Raises:
        gr.Error: If no model bundle is loaded.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    log_parts = []
    fallback_count = 0
    any_bundled = False
    for task_id in ("1", "2", "3", "4", "5"):
        if quick_mode:
            rc, task_log, bundled = _run_quick_task(task_id, model_bundle, input_text, task5_cfg)
        else:
            rc, task_log, bundled = _run_analysis_cmd(
                task_id,
                model_bundle["ckpt_path"],
                output_dir,
                input_text,
                task4_phase,
                task5_cfg.get("samples", 50),
            )
        banner = "=" * 22
        log_parts.append(f"\n\n{banner} TASK {task_id} {banner}\n{task_log}")
        if bundled:
            any_bundled = True
        if rc != 0:
            fallback_count += 1
    if fallback_count or any_bundled:
        # Make the pre-computed outputs available whenever any task could
        # not produce live results.
        _bundle_task_outputs(model_bundle, output_dir)
    if fallback_count:
        log_parts.append(f"\n\n--- Live input summary ---\n{_live_input_summary(model_bundle, input_text)}")
        status = f"Run-all finished with {fallback_count} fallback task(s)."
    elif any_bundled:
        status = "Run-all loaded from bundled analysis outputs."
    else:
        status = "All 5 tasks completed."
    return status, "".join(log_parts)
1058
-
1059
-
1060
  def _read_text(path):
1061
  if not os.path.exists(path):
1062
  return "Not found."
@@ -1211,7 +1132,7 @@ CUSTOM_CSS = """
1211
  """
1212
 
1213
 
1214
- with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1215
  model_state = gr.State(None)
1216
  bg_job_state = gr.State("")
1217
 
@@ -1342,12 +1263,6 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1342
  msg = f"Found {len(choices)} checkpoint(s)." if choices else "No checkpoints found."
1343
  return gr.Dropdown(choices=choices, value=value), msg
1344
 
1345
def auto_load_default():
    """Load the default checkpoint at startup.

    Returns the empty-state tuple (no model, message, empty settings,
    default batch size, default analysis output dir) when no checkpoints
    are available; otherwise delegates to ``load_selected_model`` with the
    default checkpoint label.
    """
    # An empty checkpoint map means there is nothing to auto-load.
    if not checkpoint_map():
        return None, "No checkpoints found.", {}, 64, DEFAULT_ANALYSIS_OUT
    return load_selected_model(default_checkpoint_label())
1350
-
1351
  refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown, load_status])
1352
  load_btn.click(
1353
  fn=load_selected_model_with_outputs,
 
135
  ]
136
  HF_DEFAULT_MODEL_TYPE = os.environ.get("HF_DEFAULT_MODEL_TYPE", "d3pm_cross_attention")
137
  HF_DEFAULT_INCLUDE_NEG = os.environ.get("HF_DEFAULT_INCLUDE_NEG", "false")
 
 
138
  HF_DEFAULT_NUM_STEPS = os.environ.get("HF_DEFAULT_NUM_STEPS")
139
  HF_DEFAULT_MODEL_SETTINGS_FILE = os.environ.get("HF_DEFAULT_MODEL_SETTINGS_FILE", "model_settings.json")
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  def _download_hf_model_settings():
143
  try:
 
147
  repo_id=HF_DEFAULT_MODEL_REPO,
148
  filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
149
  local_dir=str(cache_dir),
 
150
  )
151
  with open(settings_path, "r", encoding="utf-8") as f:
152
  data = json.load(f)
 
172
  repo_id=repo_id,
173
  filename=filename,
174
  local_dir=str(cache_dir),
 
175
  )
176
  except Exception:
177
  return None
 
184
  repo_id=repo_id,
185
  filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
186
  local_dir=str(cache_dir),
 
187
  )
188
  with open(settings_path, "r", encoding="utf-8") as f:
189
  data = json.load(f)
 
275
  if not cps:
276
  return None
277
  for item in cps:
278
+ path = item.get("path")
279
+ if path and path.endswith("ablation_results/T4/best_model.pt"):
280
+ return item["label"]
281
+ for item in cps:
282
+ if item.get("repo_id") == HF_DEFAULT_MODEL_REPO:
283
  return item["label"]
284
  return cps[0]["label"]
285
 
 
429
  return presets.get(preset_name, presets["Balanced"])
430
 
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  def save_generation(experiment, record):
433
  ts = datetime.now().strftime("%Y%m%d")
434
  path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
 
978
  return status, log, task_states, flow
979
 
980
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
981
  def _read_text(path):
982
  if not os.path.exists(path):
983
  return "Not found."
 
1132
  """
1133
 
1134
 
1135
+ with gr.Blocks(title="Sanskrit Diffusion Model", css=CUSTOM_CSS) as demo:
1136
  model_state = gr.State(None)
1137
  bg_job_state = gr.State("")
1138
 
 
1263
  msg = f"Found {len(choices)} checkpoint(s)." if choices else "No checkpoints found."
1264
  return gr.Dropdown(choices=choices, value=value), msg
1265
 
 
 
 
 
 
 
1266
  refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown, load_status])
1267
  load_btn.click(
1268
  fn=load_selected_model_with_outputs,