Shreyas Meher committed on
Commit
396346e
·
1 Parent(s): 78221a1

Add model download zips, custom model loading for multiclass

Browse files
Files changed (1) hide show
  1. app.py +70 -37
app.py CHANGED
@@ -419,10 +419,14 @@ def text_classification(text, custom_model=None, custom_tokenizer=None):
419
  return handle_error(e)
420
 
421
 
422
- def multilabel_classification(text):
423
  if not text:
424
  return "Please provide text for classification."
425
  try:
 
 
 
 
426
  inputs = multi_clf_tokenizer(
427
  text, return_tensors='pt', truncation=True, padding=True
428
  ).to(device)
@@ -838,19 +842,20 @@ def predict_finetuned(text, model_state, tokenizer_state, num_labels_state):
838
  return predict_with_model(text, model_state, tokenizer_state)
839
 
840
 
841
- def save_finetuned_model(save_path, model_state, tokenizer_state):
842
- """Save the finetuned model and tokenizer to disk."""
843
  if model_state is None:
844
- return "No model to save. Please train a model first."
845
- if not save_path:
846
- return "Please specify a save directory."
847
  try:
848
- os.makedirs(save_path, exist_ok=True)
849
- model_state.save_pretrained(save_path)
850
- tokenizer_state.save_pretrained(save_path)
851
- return f"Model saved successfully to: {save_path}"
 
 
 
852
  except Exception as e:
853
- return f"Error saving model: {str(e)}"
854
 
855
 
856
  def load_custom_model(path):
@@ -1316,19 +1321,20 @@ def al_submit_and_continue(
1316
  )
1317
 
1318
 
1319
- def al_save_model(save_path, al_model, al_tokenizer):
1320
- """Save the active-learning model to disk."""
1321
  if al_model is None:
1322
- return "No model to save. Run at least one round first."
1323
- if not save_path:
1324
- return "Please specify a save directory."
1325
  try:
1326
- os.makedirs(save_path, exist_ok=True)
1327
- al_model.save_pretrained(save_path)
1328
- al_tokenizer.save_pretrained(save_path)
1329
- return f"Model saved to: {save_path}"
 
 
 
1330
  except Exception as e:
1331
- return f"Error saving model: {str(e)}"
1332
 
1333
 
1334
  def load_example_active_learning():
@@ -1796,8 +1802,12 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1796
  "Identify multiple event types in text. Each category is scored "
1797
  "independently: **Armed Assault**, **Bombing/Explosion**, "
1798
  "**Kidnapping**, **Other**. Categories above 50% confidence "
1799
- "are highlighted."
1800
  ))
 
 
 
 
1801
  with gr.Row(equal_height=True):
1802
  with gr.Column():
1803
  multi_input = gr.Textbox(
@@ -1816,6 +1826,22 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1816
  multi_csv_out = gr.File(label="Download Results")
1817
  multi_csv_btn = gr.Button("Process CSV", variant="secondary")
1818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1819
  # ================================================================
1820
  # QUESTION ANSWERING TAB
1821
  # ================================================================
@@ -1983,11 +2009,9 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1983
  with gr.Column(visible=False) as ft_actions_col:
1984
  with gr.Row(equal_height=True):
1985
  with gr.Column():
1986
- gr.Markdown("**Save model**")
1987
- ft_save_path = gr.Textbox(
1988
- label="Save Directory", value="./finetuned_model",
1989
- )
1990
- ft_save_btn = gr.Button("Save", variant="secondary")
1991
  ft_save_status = gr.Markdown("")
1992
  with gr.Column():
1993
  gr.Markdown("**Batch predictions**")
@@ -2150,12 +2174,10 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
2150
 
2151
  al_chart = gr.Plot(label="Metrics Across Rounds")
2152
 
2153
- gr.Markdown("### Save Model")
2154
  with gr.Row():
2155
- al_save_path = gr.Textbox(
2156
- label="Save Directory", value="./al_model",
2157
- )
2158
- al_save_btn = gr.Button("Save", variant="secondary")
2159
  al_save_status = gr.Markdown("")
2160
 
2161
  # ---- FOOTER ----
@@ -2212,11 +2234,22 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
2212
 
2213
  # Multilabel Classification
2214
  multi_btn.click(
2215
- fn=multilabel_classification, inputs=[multi_input], outputs=[multi_output],
 
 
2216
  )
2217
  multi_csv_btn.click(
2218
  fn=process_csv_multilabel, inputs=[multi_csv_in], outputs=[multi_csv_out],
2219
  )
 
 
 
 
 
 
 
 
 
2220
 
2221
  # Question Answering
2222
  qa_btn.click(
@@ -2264,8 +2297,8 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
2264
  # Save finetuned model
2265
  ft_save_btn.click(
2266
  fn=save_finetuned_model,
2267
- inputs=[ft_save_path, ft_model_state, ft_tokenizer_state],
2268
- outputs=[ft_save_status],
2269
  )
2270
 
2271
  # Batch predictions with finetuned model
@@ -2313,8 +2346,8 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
2313
 
2314
  al_save_btn.click(
2315
  fn=al_save_model,
2316
- inputs=[al_save_path, al_model_state, al_tokenizer_state],
2317
- outputs=[al_save_status],
2318
  )
2319
 
2320
  # Model comparison
 
419
  return handle_error(e)
420
 
421
 
422
+ def multilabel_classification(text, custom_model=None, custom_tokenizer=None):
423
  if not text:
424
  return "Please provide text for classification."
425
  try:
426
+ # Use custom model if loaded
427
+ if custom_model is not None and custom_tokenizer is not None:
428
+ return predict_with_model(text, custom_model, custom_tokenizer)
429
+
430
  inputs = multi_clf_tokenizer(
431
  text, return_tensors='pt', truncation=True, padding=True
432
  ).to(device)
 
842
  return predict_with_model(text, model_state, tokenizer_state)
843
 
844
 
845
+ def save_finetuned_model(model_state, tokenizer_state):
846
+ """Save the finetuned model as a downloadable zip file."""
847
  if model_state is None:
848
+ return None, "No model to save. Please train a model first."
 
 
849
  try:
850
+ save_dir = tempfile.mkdtemp(prefix='conflibert_save_')
851
+ model_state.save_pretrained(save_dir)
852
+ tokenizer_state.save_pretrained(save_dir)
853
+ import shutil
854
+ zip_path = os.path.join(tempfile.gettempdir(), 'finetuned_model')
855
+ shutil.make_archive(zip_path, 'zip', save_dir)
856
+ return zip_path + '.zip', "Model ready for download."
857
  except Exception as e:
858
+ return None, f"Error saving model: {str(e)}"
859
 
860
 
861
  def load_custom_model(path):
 
1321
  )
1322
 
1323
 
1324
+ def al_save_model(al_model, al_tokenizer):
1325
+ """Save the active-learning model as a downloadable zip file."""
1326
  if al_model is None:
1327
+ return None, "No model to save. Run at least one round first."
 
 
1328
  try:
1329
+ save_dir = tempfile.mkdtemp(prefix='conflibert_al_save_')
1330
+ al_model.save_pretrained(save_dir)
1331
+ al_tokenizer.save_pretrained(save_dir)
1332
+ import shutil
1333
+ zip_path = os.path.join(tempfile.gettempdir(), 'al_model')
1334
+ shutil.make_archive(zip_path, 'zip', save_dir)
1335
+ return zip_path + '.zip', "Model ready for download."
1336
  except Exception as e:
1337
+ return None, f"Error saving model: {str(e)}"
1338
 
1339
 
1340
  def load_example_active_learning():
 
1802
  "Identify multiple event types in text. Each category is scored "
1803
  "independently: **Armed Assault**, **Bombing/Explosion**, "
1804
  "**Kidnapping**, **Other**. Categories above 50% confidence "
1805
+ "are highlighted. Load a custom finetuned model below."
1806
  ))
1807
+
1808
+ custom_multi_model = gr.State(None)
1809
+ custom_multi_tokenizer = gr.State(None)
1810
+
1811
  with gr.Row(equal_height=True):
1812
  with gr.Column():
1813
  multi_input = gr.Textbox(
 
1826
  multi_csv_out = gr.File(label="Download Results")
1827
  multi_csv_btn = gr.Button("Process CSV", variant="secondary")
1828
 
1829
+ with gr.Accordion("Load Custom Model", open=False):
1830
+ gr.Markdown(
1831
+ "Load a finetuned multiclass model from a local directory "
1832
+ "to use instead of the default pretrained classifier."
1833
+ )
1834
+ multi_model_path = gr.Textbox(
1835
+ label="Model directory path",
1836
+ placeholder="e.g., ./finetuned_model",
1837
+ )
1838
+ with gr.Row():
1839
+ multi_load_btn = gr.Button("Load Model", variant="secondary")
1840
+ multi_reset_btn = gr.Button(
1841
+ "Reset to Pretrained", variant="secondary",
1842
+ )
1843
+ multi_status = gr.Markdown("")
1844
+
1845
  # ================================================================
1846
  # QUESTION ANSWERING TAB
1847
  # ================================================================
 
2009
  with gr.Column(visible=False) as ft_actions_col:
2010
  with gr.Row(equal_height=True):
2011
  with gr.Column():
2012
+ gr.Markdown("**Download model**")
2013
+ ft_save_btn = gr.Button("Prepare Download", variant="secondary")
2014
+ ft_save_file = gr.File(label="Download Model (.zip)")
 
 
2015
  ft_save_status = gr.Markdown("")
2016
  with gr.Column():
2017
  gr.Markdown("**Batch predictions**")
 
2174
 
2175
  al_chart = gr.Plot(label="Metrics Across Rounds")
2176
 
2177
+ gr.Markdown("### Download Model")
2178
  with gr.Row():
2179
+ al_save_btn = gr.Button("Prepare Download", variant="secondary")
2180
+ al_save_file = gr.File(label="Download Model (.zip)")
 
 
2181
  al_save_status = gr.Markdown("")
2182
 
2183
  # ---- FOOTER ----
 
2234
 
2235
  # Multilabel Classification
2236
  multi_btn.click(
2237
+ fn=multilabel_classification,
2238
+ inputs=[multi_input, custom_multi_model, custom_multi_tokenizer],
2239
+ outputs=[multi_output],
2240
  )
2241
  multi_csv_btn.click(
2242
  fn=process_csv_multilabel, inputs=[multi_csv_in], outputs=[multi_csv_out],
2243
  )
2244
+ multi_load_btn.click(
2245
+ fn=load_custom_model,
2246
+ inputs=[multi_model_path],
2247
+ outputs=[custom_multi_model, custom_multi_tokenizer, multi_status],
2248
+ )
2249
+ multi_reset_btn.click(
2250
+ fn=reset_custom_model,
2251
+ outputs=[custom_multi_model, custom_multi_tokenizer, multi_status],
2252
+ )
2253
 
2254
  # Question Answering
2255
  qa_btn.click(
 
2297
  # Save finetuned model
2298
  ft_save_btn.click(
2299
  fn=save_finetuned_model,
2300
+ inputs=[ft_model_state, ft_tokenizer_state],
2301
+ outputs=[ft_save_file, ft_save_status],
2302
  )
2303
 
2304
  # Batch predictions with finetuned model
 
2346
 
2347
  al_save_btn.click(
2348
  fn=al_save_model,
2349
+ inputs=[al_model_state, al_tokenizer_state],
2350
+ outputs=[al_save_file, al_save_status],
2351
  )
2352
 
2353
  # Model comparison