Spaces:
Running
Running
Shreyas Meher committed on
Commit ·
396346e
1
Parent(s): 78221a1
Add model download zips, custom model loading for multiclass
Browse files
app.py
CHANGED
|
@@ -419,10 +419,14 @@ def text_classification(text, custom_model=None, custom_tokenizer=None):
|
|
| 419 |
return handle_error(e)
|
| 420 |
|
| 421 |
|
| 422 |
-
def multilabel_classification(text):
|
| 423 |
if not text:
|
| 424 |
return "Please provide text for classification."
|
| 425 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
inputs = multi_clf_tokenizer(
|
| 427 |
text, return_tensors='pt', truncation=True, padding=True
|
| 428 |
).to(device)
|
|
@@ -838,19 +842,20 @@ def predict_finetuned(text, model_state, tokenizer_state, num_labels_state):
|
|
| 838 |
return predict_with_model(text, model_state, tokenizer_state)
|
| 839 |
|
| 840 |
|
| 841 |
-
def save_finetuned_model(
|
| 842 |
-
"""Save the finetuned model
|
| 843 |
if model_state is None:
|
| 844 |
-
return "No model to save. Please train a model first."
|
| 845 |
-
if not save_path:
|
| 846 |
-
return "Please specify a save directory."
|
| 847 |
try:
|
| 848 |
-
|
| 849 |
-
model_state.save_pretrained(
|
| 850 |
-
tokenizer_state.save_pretrained(
|
| 851 |
-
|
|
|
|
|
|
|
|
|
|
| 852 |
except Exception as e:
|
| 853 |
-
return f"Error saving model: {str(e)}"
|
| 854 |
|
| 855 |
|
| 856 |
def load_custom_model(path):
|
|
@@ -1316,19 +1321,20 @@ def al_submit_and_continue(
|
|
| 1316 |
)
|
| 1317 |
|
| 1318 |
|
| 1319 |
-
def al_save_model(
|
| 1320 |
-
"""Save the active-learning model
|
| 1321 |
if al_model is None:
|
| 1322 |
-
return "No model to save. Run at least one round first."
|
| 1323 |
-
if not save_path:
|
| 1324 |
-
return "Please specify a save directory."
|
| 1325 |
try:
|
| 1326 |
-
|
| 1327 |
-
al_model.save_pretrained(
|
| 1328 |
-
al_tokenizer.save_pretrained(
|
| 1329 |
-
|
|
|
|
|
|
|
|
|
|
| 1330 |
except Exception as e:
|
| 1331 |
-
return f"Error saving model: {str(e)}"
|
| 1332 |
|
| 1333 |
|
| 1334 |
def load_example_active_learning():
|
|
@@ -1796,8 +1802,12 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 1796 |
"Identify multiple event types in text. Each category is scored "
|
| 1797 |
"independently: **Armed Assault**, **Bombing/Explosion**, "
|
| 1798 |
"**Kidnapping**, **Other**. Categories above 50% confidence "
|
| 1799 |
-
"are highlighted."
|
| 1800 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1801 |
with gr.Row(equal_height=True):
|
| 1802 |
with gr.Column():
|
| 1803 |
multi_input = gr.Textbox(
|
|
@@ -1816,6 +1826,22 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 1816 |
multi_csv_out = gr.File(label="Download Results")
|
| 1817 |
multi_csv_btn = gr.Button("Process CSV", variant="secondary")
|
| 1818 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1819 |
# ================================================================
|
| 1820 |
# QUESTION ANSWERING TAB
|
| 1821 |
# ================================================================
|
|
@@ -1983,11 +2009,9 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 1983 |
with gr.Column(visible=False) as ft_actions_col:
|
| 1984 |
with gr.Row(equal_height=True):
|
| 1985 |
with gr.Column():
|
| 1986 |
-
gr.Markdown("**
|
| 1987 |
-
|
| 1988 |
-
|
| 1989 |
-
)
|
| 1990 |
-
ft_save_btn = gr.Button("Save", variant="secondary")
|
| 1991 |
ft_save_status = gr.Markdown("")
|
| 1992 |
with gr.Column():
|
| 1993 |
gr.Markdown("**Batch predictions**")
|
|
@@ -2150,12 +2174,10 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 2150 |
|
| 2151 |
al_chart = gr.Plot(label="Metrics Across Rounds")
|
| 2152 |
|
| 2153 |
-
gr.Markdown("###
|
| 2154 |
with gr.Row():
|
| 2155 |
-
|
| 2156 |
-
|
| 2157 |
-
)
|
| 2158 |
-
al_save_btn = gr.Button("Save", variant="secondary")
|
| 2159 |
al_save_status = gr.Markdown("")
|
| 2160 |
|
| 2161 |
# ---- FOOTER ----
|
|
@@ -2212,11 +2234,22 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 2212 |
|
| 2213 |
# Multilabel Classification
|
| 2214 |
multi_btn.click(
|
| 2215 |
-
fn=multilabel_classification,
|
|
|
|
|
|
|
| 2216 |
)
|
| 2217 |
multi_csv_btn.click(
|
| 2218 |
fn=process_csv_multilabel, inputs=[multi_csv_in], outputs=[multi_csv_out],
|
| 2219 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2220 |
|
| 2221 |
# Question Answering
|
| 2222 |
qa_btn.click(
|
|
@@ -2264,8 +2297,8 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 2264 |
# Save finetuned model
|
| 2265 |
ft_save_btn.click(
|
| 2266 |
fn=save_finetuned_model,
|
| 2267 |
-
inputs=[
|
| 2268 |
-
outputs=[ft_save_status],
|
| 2269 |
)
|
| 2270 |
|
| 2271 |
# Batch predictions with finetuned model
|
|
@@ -2313,8 +2346,8 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
|
|
| 2313 |
|
| 2314 |
al_save_btn.click(
|
| 2315 |
fn=al_save_model,
|
| 2316 |
-
inputs=[
|
| 2317 |
-
outputs=[al_save_status],
|
| 2318 |
)
|
| 2319 |
|
| 2320 |
# Model comparison
|
|
|
|
| 419 |
return handle_error(e)
|
| 420 |
|
| 421 |
|
| 422 |
+
def multilabel_classification(text, custom_model=None, custom_tokenizer=None):
|
| 423 |
if not text:
|
| 424 |
return "Please provide text for classification."
|
| 425 |
try:
|
| 426 |
+
# Use custom model if loaded
|
| 427 |
+
if custom_model is not None and custom_tokenizer is not None:
|
| 428 |
+
return predict_with_model(text, custom_model, custom_tokenizer)
|
| 429 |
+
|
| 430 |
inputs = multi_clf_tokenizer(
|
| 431 |
text, return_tensors='pt', truncation=True, padding=True
|
| 432 |
).to(device)
|
|
|
|
| 842 |
return predict_with_model(text, model_state, tokenizer_state)
|
| 843 |
|
| 844 |
|
| 845 |
+
def save_finetuned_model(model_state, tokenizer_state):
    """Save the finetuned model as a downloadable zip file.

    Args:
        model_state: Object exposing ``save_pretrained(directory)``, or
            ``None`` when no model has been trained yet.
        tokenizer_state: Object exposing ``save_pretrained(directory)``.

    Returns:
        A ``(zip_path, status_message)`` tuple. ``zip_path`` is ``None``
        when there is no model to save or when saving failed; in the
        failure case the message carries the error text.
    """
    if model_state is None:
        return None, "No model to save. Please train a model first."

    import shutil

    staging_dir = None
    try:
        staging_dir = tempfile.mkdtemp(prefix='conflibert_save_')
        model_state.save_pretrained(staging_dir)
        tokenizer_state.save_pretrained(staging_dir)
        # Write the archive into a directory unique to this request. A fixed
        # path in the shared temp dir would let concurrent requests overwrite
        # each other's archive. The file name itself stays
        # 'finetuned_model.zip'.
        archive_dir = tempfile.mkdtemp(prefix='conflibert_zip_')
        archive_path = shutil.make_archive(
            os.path.join(archive_dir, 'finetuned_model'), 'zip', staging_dir,
        )
        return archive_path, "Model ready for download."
    except Exception as e:
        return None, f"Error saving model: {str(e)}"
    finally:
        # The uncompressed copy is only needed to build the archive; remove
        # it on success and on failure so repeated saves do not fill the
        # temp directory. The archive directory is kept: the caller still
        # needs the zip file after this function returns.
        if staging_dir is not None:
            shutil.rmtree(staging_dir, ignore_errors=True)
|
| 859 |
|
| 860 |
|
| 861 |
def load_custom_model(path):
|
|
|
|
| 1321 |
)
|
| 1322 |
|
| 1323 |
|
| 1324 |
+
def al_save_model(al_model, al_tokenizer):
    """Save the active-learning model as a downloadable zip file.

    Args:
        al_model: Object exposing ``save_pretrained(directory)``, or
            ``None`` when no active-learning round has produced a model.
        al_tokenizer: Object exposing ``save_pretrained(directory)``.

    Returns:
        A ``(zip_path, status_message)`` tuple. ``zip_path`` is ``None``
        when there is no model to save or when saving failed; in the
        failure case the message carries the error text.
    """
    if al_model is None:
        return None, "No model to save. Run at least one round first."

    import shutil

    staging_dir = None
    try:
        staging_dir = tempfile.mkdtemp(prefix='conflibert_al_save_')
        al_model.save_pretrained(staging_dir)
        al_tokenizer.save_pretrained(staging_dir)
        # Write the archive into a directory unique to this request. A fixed
        # path in the shared temp dir would let concurrent requests overwrite
        # each other's archive. The file name itself stays 'al_model.zip'.
        archive_dir = tempfile.mkdtemp(prefix='conflibert_al_zip_')
        archive_path = shutil.make_archive(
            os.path.join(archive_dir, 'al_model'), 'zip', staging_dir,
        )
        return archive_path, "Model ready for download."
    except Exception as e:
        return None, f"Error saving model: {str(e)}"
    finally:
        # The uncompressed copy is only needed to build the archive; remove
        # it on success and on failure so repeated saves do not fill the
        # temp directory. The archive directory is kept: the caller still
        # needs the zip file after this function returns.
        if staging_dir is not None:
            shutil.rmtree(staging_dir, ignore_errors=True)
|
| 1338 |
|
| 1339 |
|
| 1340 |
def load_example_active_learning():
|
|
|
|
| 1802 |
"Identify multiple event types in text. Each category is scored "
|
| 1803 |
"independently: **Armed Assault**, **Bombing/Explosion**, "
|
| 1804 |
"**Kidnapping**, **Other**. Categories above 50% confidence "
|
| 1805 |
+
"are highlighted. Load a custom finetuned model below."
|
| 1806 |
))
|
| 1807 |
+
|
| 1808 |
+
custom_multi_model = gr.State(None)
|
| 1809 |
+
custom_multi_tokenizer = gr.State(None)
|
| 1810 |
+
|
| 1811 |
with gr.Row(equal_height=True):
|
| 1812 |
with gr.Column():
|
| 1813 |
multi_input = gr.Textbox(
|
|
|
|
| 1826 |
multi_csv_out = gr.File(label="Download Results")
|
| 1827 |
multi_csv_btn = gr.Button("Process CSV", variant="secondary")
|
| 1828 |
|
| 1829 |
+
with gr.Accordion("Load Custom Model", open=False):
|
| 1830 |
+
gr.Markdown(
|
| 1831 |
+
"Load a finetuned multiclass model from a local directory "
|
| 1832 |
+
"to use instead of the default pretrained classifier."
|
| 1833 |
+
)
|
| 1834 |
+
multi_model_path = gr.Textbox(
|
| 1835 |
+
label="Model directory path",
|
| 1836 |
+
placeholder="e.g., ./finetuned_model",
|
| 1837 |
+
)
|
| 1838 |
+
with gr.Row():
|
| 1839 |
+
multi_load_btn = gr.Button("Load Model", variant="secondary")
|
| 1840 |
+
multi_reset_btn = gr.Button(
|
| 1841 |
+
"Reset to Pretrained", variant="secondary",
|
| 1842 |
+
)
|
| 1843 |
+
multi_status = gr.Markdown("")
|
| 1844 |
+
|
| 1845 |
# ================================================================
|
| 1846 |
# QUESTION ANSWERING TAB
|
| 1847 |
# ================================================================
|
|
|
|
| 2009 |
with gr.Column(visible=False) as ft_actions_col:
|
| 2010 |
with gr.Row(equal_height=True):
|
| 2011 |
with gr.Column():
|
| 2012 |
+
gr.Markdown("**Download model**")
|
| 2013 |
+
ft_save_btn = gr.Button("Prepare Download", variant="secondary")
|
| 2014 |
+
ft_save_file = gr.File(label="Download Model (.zip)")
|
|
|
|
|
|
|
| 2015 |
ft_save_status = gr.Markdown("")
|
| 2016 |
with gr.Column():
|
| 2017 |
gr.Markdown("**Batch predictions**")
|
|
|
|
| 2174 |
|
| 2175 |
al_chart = gr.Plot(label="Metrics Across Rounds")
|
| 2176 |
|
| 2177 |
+
gr.Markdown("### Download Model")
|
| 2178 |
with gr.Row():
|
| 2179 |
+
al_save_btn = gr.Button("Prepare Download", variant="secondary")
|
| 2180 |
+
al_save_file = gr.File(label="Download Model (.zip)")
|
|
|
|
|
|
|
| 2181 |
al_save_status = gr.Markdown("")
|
| 2182 |
|
| 2183 |
# ---- FOOTER ----
|
|
|
|
| 2234 |
|
| 2235 |
# Multilabel Classification
|
| 2236 |
multi_btn.click(
|
| 2237 |
+
fn=multilabel_classification,
|
| 2238 |
+
inputs=[multi_input, custom_multi_model, custom_multi_tokenizer],
|
| 2239 |
+
outputs=[multi_output],
|
| 2240 |
)
|
| 2241 |
multi_csv_btn.click(
|
| 2242 |
fn=process_csv_multilabel, inputs=[multi_csv_in], outputs=[multi_csv_out],
|
| 2243 |
)
|
| 2244 |
+
multi_load_btn.click(
|
| 2245 |
+
fn=load_custom_model,
|
| 2246 |
+
inputs=[multi_model_path],
|
| 2247 |
+
outputs=[custom_multi_model, custom_multi_tokenizer, multi_status],
|
| 2248 |
+
)
|
| 2249 |
+
multi_reset_btn.click(
|
| 2250 |
+
fn=reset_custom_model,
|
| 2251 |
+
outputs=[custom_multi_model, custom_multi_tokenizer, multi_status],
|
| 2252 |
+
)
|
| 2253 |
|
| 2254 |
# Question Answering
|
| 2255 |
qa_btn.click(
|
|
|
|
| 2297 |
# Save finetuned model
|
| 2298 |
ft_save_btn.click(
|
| 2299 |
fn=save_finetuned_model,
|
| 2300 |
+
inputs=[ft_model_state, ft_tokenizer_state],
|
| 2301 |
+
outputs=[ft_save_file, ft_save_status],
|
| 2302 |
)
|
| 2303 |
|
| 2304 |
# Batch predictions with finetuned model
|
|
|
|
| 2346 |
|
| 2347 |
al_save_btn.click(
|
| 2348 |
fn=al_save_model,
|
| 2349 |
+
inputs=[al_model_state, al_tokenizer_state],
|
| 2350 |
+
outputs=[al_save_file, al_save_status],
|
| 2351 |
)
|
| 2352 |
|
| 2353 |
# Model comparison
|