bhsinghgrid commited on
Commit
4cfbad5
·
verified ·
1 Parent(s): 9a90692

Fix Space build and Gradio schema crash (state + deps)

Browse files
Files changed (3) hide show
  1. __pycache__/app.cpython-311.pyc +0 -0
  2. app.py +22 -11
  3. requirements.txt +1 -3
__pycache__/app.cpython-311.pyc CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
 
app.py CHANGED
@@ -16,6 +16,7 @@ from inference import _build_tokenizers, _resolve_device, load_model, run_infere
16
  RESULTS_DIR = "generated_results"
17
  DEFAULT_ANALYSIS_OUT = "analysis/outputs"
18
  os.makedirs(RESULTS_DIR, exist_ok=True)
 
19
 
20
 
21
  def discover_checkpoints():
@@ -131,6 +132,7 @@ def load_selected_model(checkpoint_label):
131
  "src_tok": src_tok,
132
  "tgt_tok": tgt_tok,
133
  }
 
134
  model_info = {
135
  "checkpoint": ckpt_path,
136
  "experiment": experiment,
@@ -146,7 +148,19 @@ def load_selected_model(checkpoint_label):
146
  }
147
  status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
148
  suggested_out = os.path.join("analysis", "outputs_ui", experiment)
149
- return bundle, status, json.dumps(model_info, ensure_ascii=False, indent=2), cfg["inference"]["num_steps"], suggested_out
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
 
152
  def apply_preset(preset_name):
@@ -193,7 +207,7 @@ def save_generation(experiment, record):
193
 
194
 
195
  def generate_from_ui(
196
- model_bundle,
197
  input_text,
198
  temperature,
199
  top_k,
@@ -202,8 +216,7 @@ def generate_from_ui(
202
  num_steps,
203
  clean_output,
204
  ):
205
- if not model_bundle:
206
- raise gr.Error("Load a model first.")
207
  if not input_text.strip():
208
  raise gr.Error("Enter input text first.")
209
 
@@ -276,17 +289,15 @@ def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati
276
  return proc.returncode, log
277
 
278
 
279
- def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
280
- if not model_bundle:
281
- raise gr.Error("Load a model first.")
282
  code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
283
  status = f"Task {task} {'completed' if code == 0 else 'failed'} (exit={code})."
284
  return status, log
285
 
286
 
287
- def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
288
- if not model_bundle:
289
- raise gr.Error("Load a model first.")
290
  logs = []
291
  failures = 0
292
  for task in ["1", "2", "3", "4", "5"]:
@@ -349,7 +360,7 @@ CUSTOM_CSS = """
349
 
350
 
351
  with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
352
- model_state = gr.State(None)
353
 
354
  gr.Markdown(
355
  """
 
16
  RESULTS_DIR = "generated_results"
17
  DEFAULT_ANALYSIS_OUT = "analysis/outputs"
18
  os.makedirs(RESULTS_DIR, exist_ok=True)
19
+ MODEL_CACHE = {}
20
 
21
 
22
  def discover_checkpoints():
 
132
  "src_tok": src_tok,
133
  "tgt_tok": tgt_tok,
134
  }
135
+ MODEL_CACHE[checkpoint_label] = bundle
136
  model_info = {
137
  "checkpoint": ckpt_path,
138
  "experiment": experiment,
 
148
  }
149
  status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
150
  suggested_out = os.path.join("analysis", "outputs_ui", experiment)
151
+ return checkpoint_label, status, json.dumps(model_info, ensure_ascii=False, indent=2), cfg["inference"]["num_steps"], suggested_out
152
+
153
+
154
+ def _get_bundle(model_key: str):
155
+ if not model_key:
156
+ raise gr.Error("Load a model first.")
157
+ if model_key not in MODEL_CACHE:
158
+ mapping = checkpoint_map()
159
+ if model_key not in mapping:
160
+ raise gr.Error("Selected checkpoint is no longer available. Refresh and load again.")
161
+ # Lazy reload if Space process restarted.
162
+ load_selected_model(model_key)
163
+ return MODEL_CACHE[model_key]
164
 
165
 
166
  def apply_preset(preset_name):
 
207
 
208
 
209
  def generate_from_ui(
210
+ model_key,
211
  input_text,
212
  temperature,
213
  top_k,
 
216
  num_steps,
217
  clean_output,
218
  ):
219
+ model_bundle = _get_bundle(model_key)
 
220
  if not input_text.strip():
221
  raise gr.Error("Enter input text first.")
222
 
 
289
  return proc.returncode, log
290
 
291
 
292
+ def run_single_task(model_key, task, output_dir, input_text, task4_phase):
293
+ model_bundle = _get_bundle(model_key)
 
294
  code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
295
  status = f"Task {task} {'completed' if code == 0 else 'failed'} (exit={code})."
296
  return status, log
297
 
298
 
299
+ def run_all_tasks(model_key, output_dir, input_text, task4_phase):
300
+ model_bundle = _get_bundle(model_key)
 
301
  logs = []
302
  failures = 0
303
  for task in ["1", "2", "3", "4", "5"]:
 
360
 
361
 
362
  with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
363
+ model_state = gr.State("")
364
 
365
  gr.Markdown(
366
  """
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
- gradio==5.23.3
2
- gradio_client==1.8.0
3
  torch>=2.2
4
  numpy>=1.24
5
  tqdm>=4.66
6
- huggingface_hub>=0.28,<1.0
7
  tokenizers>=0.15
8
  datasets>=2.20
9
  scikit-learn>=1.4
 
 
 
1
  torch>=2.2
2
  numpy>=1.24
3
  tqdm>=4.66
4
+ huggingface_hub==0.25.2
5
  tokenizers>=0.15
6
  datasets>=2.20
7
  scikit-learn>=1.4