Spaces:
Build error
Build error
Commit ·
42c8298
1
Parent(s): d655f51
add fix and push ft
Browse files
app.py
CHANGED
|
@@ -47,7 +47,7 @@ def get_ir_evaluator(eval_ds):
|
|
| 47 |
|
| 48 |
|
| 49 |
@spaces.GPU(duration=3600)
|
| 50 |
-
def train(hf_token, dataset_id, model_id, num_epochs, dev):
|
| 51 |
|
| 52 |
ds = load_dataset(dataset_id, split="train", token=hf_token)
|
| 53 |
ds = ds.shuffle(seed=42)
|
|
@@ -110,6 +110,8 @@ def train(hf_token, dataset_id, model_id, num_epochs, dev):
|
|
| 110 |
print(ir_evaluator.primary_metric)
|
| 111 |
print(ft_metrics[ir_evaluator.primary_metric])
|
| 112 |
|
|
|
|
|
|
|
| 113 |
|
| 114 |
metrics = pd.DataFrame([base_metrics, ft_metrics]).T
|
| 115 |
print(metrics)
|
|
@@ -119,5 +121,5 @@ def train(hf_token, dataset_id, model_id, num_epochs, dev):
|
|
| 119 |
## logs to UI
|
| 120 |
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
|
| 121 |
|
| 122 |
-
demo = gr.Interface(fn=
|
| 123 |
demo.launch()
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
@spaces.GPU(duration=3600)
|
| 50 |
+
def train(hf_token, dataset_id, model_id, num_epochs, dev=True):
|
| 51 |
|
| 52 |
ds = load_dataset(dataset_id, split="train", token=hf_token)
|
| 53 |
ds = ds.shuffle(seed=42)
|
|
|
|
| 110 |
print(ir_evaluator.primary_metric)
|
| 111 |
print(ft_metrics[ir_evaluator.primary_metric])
|
| 112 |
|
| 113 |
+
if not dev: model.push_to_hub("fine-tuned-sentence-transformer", private=True, token=hf_token)
|
| 114 |
+
|
| 115 |
|
| 116 |
metrics = pd.DataFrame([base_metrics, ft_metrics]).T
|
| 117 |
print(metrics)
|
|
|
|
| 121 |
## logs to UI
|
| 122 |
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
|
| 123 |
|
| 124 |
+
demo = gr.Interface(fn=train, inputs=["text", "text", "text", "number", "bool"], outputs=["text"]) # "dataframe"
|
| 125 |
demo.launch()
|