Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -135,39 +135,9 @@ HF_MODEL_REPOS = [
|
|
| 135 |
]
|
| 136 |
HF_DEFAULT_MODEL_TYPE = os.environ.get("HF_DEFAULT_MODEL_TYPE", "d3pm_cross_attention")
|
| 137 |
HF_DEFAULT_INCLUDE_NEG = os.environ.get("HF_DEFAULT_INCLUDE_NEG", "false")
|
| 138 |
-
HF_DEFAULT_TYPE = os.environ.get("HF_DEFAULT_TYPE", "d3pm_encoder_decoder")
|
| 139 |
-
HF_DEFAULT_NEG = os.environ.get("HF_DEFAULT_NEG", "false")
|
| 140 |
HF_DEFAULT_NUM_STEPS = os.environ.get("HF_DEFAULT_NUM_STEPS")
|
| 141 |
HF_DEFAULT_MODEL_SETTINGS_FILE = os.environ.get("HF_DEFAULT_MODEL_SETTINGS_FILE", "model_settings.json")
|
| 142 |
|
| 143 |
-
# import os
|
| 144 |
-
# # from huggingface_hub import hf_hub_download
|
| 145 |
-
#
|
| 146 |
-
# HF_CHECKPOINT_REPO = os.environ.get("HF_CHECKPOINT_REPO", "bhsinghgrid/devflow2")
|
| 147 |
-
# HF_CHECKPOINT_FILE = os.environ.get("HF_CHECKPOINT_FILE", "best_model.pt")
|
| 148 |
-
|
| 149 |
-
checkpoint_path = hf_hub_download(
|
| 150 |
-
repo_id=HF_CHECKPOINT_REPO,
|
| 151 |
-
filename=HF_CHECKPOINT_FILE,
|
| 152 |
-
repo_type="model",
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def _download_hf_default_checkpoint():
|
| 158 |
-
try:
|
| 159 |
-
cache_dir = Path(".hf_model_cache")
|
| 160 |
-
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 161 |
-
ckpt = hf_hub_download(
|
| 162 |
-
repo_id=HF_DEFAULT_MODEL_REPO,
|
| 163 |
-
filename=HF_DEFAULT_MODEL_FILE,
|
| 164 |
-
local_dir=str(cache_dir),
|
| 165 |
-
local_dir_use_symlinks=False,
|
| 166 |
-
)
|
| 167 |
-
return ckpt
|
| 168 |
-
except Exception:
|
| 169 |
-
return None
|
| 170 |
-
|
| 171 |
|
| 172 |
def _download_hf_model_settings():
|
| 173 |
try:
|
|
@@ -177,7 +147,6 @@ def _download_hf_model_settings():
|
|
| 177 |
repo_id=HF_DEFAULT_MODEL_REPO,
|
| 178 |
filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
|
| 179 |
local_dir=str(cache_dir),
|
| 180 |
-
local_dir_use_symlinks=False,
|
| 181 |
)
|
| 182 |
with open(settings_path, "r", encoding="utf-8") as f:
|
| 183 |
data = json.load(f)
|
|
@@ -203,7 +172,6 @@ def _download_hf_checkpoint(repo_id: str, filename: str = "best_model.pt"):
|
|
| 203 |
repo_id=repo_id,
|
| 204 |
filename=filename,
|
| 205 |
local_dir=str(cache_dir),
|
| 206 |
-
local_dir_use_symlinks=False,
|
| 207 |
)
|
| 208 |
except Exception:
|
| 209 |
return None
|
|
@@ -216,7 +184,6 @@ def _download_hf_settings_for_repo(repo_id: str):
|
|
| 216 |
repo_id=repo_id,
|
| 217 |
filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
|
| 218 |
local_dir=str(cache_dir),
|
| 219 |
-
local_dir_use_symlinks=False,
|
| 220 |
)
|
| 221 |
with open(settings_path, "r", encoding="utf-8") as f:
|
| 222 |
data = json.load(f)
|
|
@@ -308,7 +275,11 @@ def default_checkpoint_label():
|
|
| 308 |
if not cps:
|
| 309 |
return None
|
| 310 |
for item in cps:
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
return item["label"]
|
| 313 |
return cps[0]["label"]
|
| 314 |
|
|
@@ -458,26 +429,6 @@ def apply_preset(preset_name):
|
|
| 458 |
return presets.get(preset_name, presets["Balanced"])
|
| 459 |
|
| 460 |
|
| 461 |
-
def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
|
| 462 |
-
text = " ".join(text.split())
|
| 463 |
-
if not text:
|
| 464 |
-
return text
|
| 465 |
-
tokens = text.split()
|
| 466 |
-
cleaned = []
|
| 467 |
-
prev = None
|
| 468 |
-
run = 0
|
| 469 |
-
for tok in tokens:
|
| 470 |
-
if tok == prev:
|
| 471 |
-
run += 1
|
| 472 |
-
else:
|
| 473 |
-
prev = tok
|
| 474 |
-
run = 1
|
| 475 |
-
if run <= max_consecutive:
|
| 476 |
-
cleaned.append(tok)
|
| 477 |
-
out = " ".join(cleaned).replace(" ।", "।").replace(" ॥", "॥")
|
| 478 |
-
return " ".join(out.split())
|
| 479 |
-
|
| 480 |
-
|
| 481 |
def save_generation(experiment, record):
|
| 482 |
ts = datetime.now().strftime("%Y%m%d")
|
| 483 |
path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
|
|
@@ -1027,36 +978,6 @@ def run_single_task(model_bundle, task, output_dir, input_text, task4_phase, tas
|
|
| 1027 |
return status, log, task_states, flow
|
| 1028 |
|
| 1029 |
|
| 1030 |
-
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase, task5_cfg, quick_mode):
|
| 1031 |
-
if not model_bundle:
|
| 1032 |
-
raise gr.Error("Load a model first.")
|
| 1033 |
-
logs = []
|
| 1034 |
-
failures = 0
|
| 1035 |
-
used_bundled_any = False
|
| 1036 |
-
for task in ["1", "2", "3", "4", "5"]:
|
| 1037 |
-
if quick_mode:
|
| 1038 |
-
code, log, used_bundled = _run_quick_task(task, model_bundle, input_text, task5_cfg)
|
| 1039 |
-
else:
|
| 1040 |
-
code, log, used_bundled = _run_analysis_cmd(
|
| 1041 |
-
task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50)
|
| 1042 |
-
)
|
| 1043 |
-
logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
|
| 1044 |
-
used_bundled_any = used_bundled_any or used_bundled
|
| 1045 |
-
if code != 0:
|
| 1046 |
-
failures += 1
|
| 1047 |
-
if failures or used_bundled_any:
|
| 1048 |
-
_bundle_task_outputs(model_bundle, output_dir)
|
| 1049 |
-
if failures:
|
| 1050 |
-
logs.append(f"\n\n--- Live input summary ---\n{_live_input_summary(model_bundle, input_text)}")
|
| 1051 |
-
if failures:
|
| 1052 |
-
status = f"Run-all finished with {failures} fallback task(s)."
|
| 1053 |
-
elif used_bundled_any:
|
| 1054 |
-
status = "Run-all loaded from bundled analysis outputs."
|
| 1055 |
-
else:
|
| 1056 |
-
status = "All 5 tasks completed."
|
| 1057 |
-
return status, "".join(logs)
|
| 1058 |
-
|
| 1059 |
-
|
| 1060 |
def _read_text(path):
|
| 1061 |
if not os.path.exists(path):
|
| 1062 |
return "Not found."
|
|
@@ -1211,7 +1132,7 @@ CUSTOM_CSS = """
|
|
| 1211 |
"""
|
| 1212 |
|
| 1213 |
|
| 1214 |
-
with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
| 1215 |
model_state = gr.State(None)
|
| 1216 |
bg_job_state = gr.State("")
|
| 1217 |
|
|
@@ -1342,12 +1263,6 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 1342 |
msg = f"Found {len(choices)} checkpoint(s)." if choices else "No checkpoints found."
|
| 1343 |
return gr.Dropdown(choices=choices, value=value), msg
|
| 1344 |
|
| 1345 |
-
def auto_load_default():
|
| 1346 |
-
choices = list(checkpoint_map().keys())
|
| 1347 |
-
if not choices:
|
| 1348 |
-
return None, "No checkpoints found.", {}, 64, DEFAULT_ANALYSIS_OUT
|
| 1349 |
-
return load_selected_model(default_checkpoint_label())
|
| 1350 |
-
|
| 1351 |
refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown, load_status])
|
| 1352 |
load_btn.click(
|
| 1353 |
fn=load_selected_model_with_outputs,
|
|
|
|
| 135 |
]
|
| 136 |
HF_DEFAULT_MODEL_TYPE = os.environ.get("HF_DEFAULT_MODEL_TYPE", "d3pm_cross_attention")
|
| 137 |
HF_DEFAULT_INCLUDE_NEG = os.environ.get("HF_DEFAULT_INCLUDE_NEG", "false")
|
|
|
|
|
|
|
| 138 |
HF_DEFAULT_NUM_STEPS = os.environ.get("HF_DEFAULT_NUM_STEPS")
|
| 139 |
HF_DEFAULT_MODEL_SETTINGS_FILE = os.environ.get("HF_DEFAULT_MODEL_SETTINGS_FILE", "model_settings.json")
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
def _download_hf_model_settings():
|
| 143 |
try:
|
|
|
|
| 147 |
repo_id=HF_DEFAULT_MODEL_REPO,
|
| 148 |
filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
|
| 149 |
local_dir=str(cache_dir),
|
|
|
|
| 150 |
)
|
| 151 |
with open(settings_path, "r", encoding="utf-8") as f:
|
| 152 |
data = json.load(f)
|
|
|
|
| 172 |
repo_id=repo_id,
|
| 173 |
filename=filename,
|
| 174 |
local_dir=str(cache_dir),
|
|
|
|
| 175 |
)
|
| 176 |
except Exception:
|
| 177 |
return None
|
|
|
|
| 184 |
repo_id=repo_id,
|
| 185 |
filename=HF_DEFAULT_MODEL_SETTINGS_FILE,
|
| 186 |
local_dir=str(cache_dir),
|
|
|
|
| 187 |
)
|
| 188 |
with open(settings_path, "r", encoding="utf-8") as f:
|
| 189 |
data = json.load(f)
|
|
|
|
| 275 |
if not cps:
|
| 276 |
return None
|
| 277 |
for item in cps:
|
| 278 |
+
path = item.get("path")
|
| 279 |
+
if path and path.endswith("ablation_results/T4/best_model.pt"):
|
| 280 |
+
return item["label"]
|
| 281 |
+
for item in cps:
|
| 282 |
+
if item.get("repo_id") == HF_DEFAULT_MODEL_REPO:
|
| 283 |
return item["label"]
|
| 284 |
return cps[0]["label"]
|
| 285 |
|
|
|
|
| 429 |
return presets.get(preset_name, presets["Balanced"])
|
| 430 |
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
def save_generation(experiment, record):
|
| 433 |
ts = datetime.now().strftime("%Y%m%d")
|
| 434 |
path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
|
|
|
|
| 978 |
return status, log, task_states, flow
|
| 979 |
|
| 980 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
def _read_text(path):
|
| 982 |
if not os.path.exists(path):
|
| 983 |
return "Not found."
|
|
|
|
| 1132 |
"""
|
| 1133 |
|
| 1134 |
|
| 1135 |
+
with gr.Blocks(title="Sanskrit Diffusion Model", css=CUSTOM_CSS) as demo:
|
| 1136 |
model_state = gr.State(None)
|
| 1137 |
bg_job_state = gr.State("")
|
| 1138 |
|
|
|
|
| 1263 |
msg = f"Found {len(choices)} checkpoint(s)." if choices else "No checkpoints found."
|
| 1264 |
return gr.Dropdown(choices=choices, value=value), msg
|
| 1265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1266 |
refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown, load_status])
|
| 1267 |
load_btn.click(
|
| 1268 |
fn=load_selected_model_with_outputs,
|