Spaces:
No application file
No application file
Update finetune.py
Browse files- finetune.py +4 -4
finetune.py
CHANGED
|
@@ -26,7 +26,7 @@ def summarize_text_mt5(texts, model, tokenizer):
|
|
| 26 |
max_length=512, truncation=True,
|
| 27 |
padding=True).to(model.device)
|
| 28 |
summary_ids = model.generate(inputs.input_ids,
|
| 29 |
-
max_length=
|
| 30 |
num_beams=4, length_penalty=2.0,
|
| 31 |
early_stopping=True)
|
| 32 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
|
@@ -36,7 +36,7 @@ def summarize_text_mbart50(texts, model, tokenizer):
|
|
| 36 |
inputs = tokenizer(texts, return_tensors="pt",
|
| 37 |
max_length=1024, truncation=True,
|
| 38 |
padding=True).to(model.device)
|
| 39 |
-
summary_ids = model.generate(inputs.input_ids, max_length=
|
| 40 |
num_beams=4, length_penalty=2.0,
|
| 41 |
early_stopping=True)
|
| 42 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
|
@@ -94,10 +94,10 @@ def fine_tune(model_name, finetune_type, model, tokenizer, summarize_text, train
|
|
| 94 |
print("Starting Fine-tuning...")
|
| 95 |
if model_name == "mT5":
|
| 96 |
max_input = 512
|
| 97 |
-
max_output =
|
| 98 |
else:
|
| 99 |
max_input = 1024
|
| 100 |
-
max_output =
|
| 101 |
|
| 102 |
train_dataset = train
|
| 103 |
eval_dataset = val
|
|
|
|
| 26 |
max_length=512, truncation=True,
|
| 27 |
padding=True).to(model.device)
|
| 28 |
summary_ids = model.generate(inputs.input_ids,
|
| 29 |
+
max_length=60,
|
| 30 |
num_beams=4, length_penalty=2.0,
|
| 31 |
early_stopping=True)
|
| 32 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
|
|
|
| 36 |
inputs = tokenizer(texts, return_tensors="pt",
|
| 37 |
max_length=1024, truncation=True,
|
| 38 |
padding=True).to(model.device)
|
| 39 |
+
summary_ids = model.generate(inputs.input_ids, max_length=60,
|
| 40 |
num_beams=4, length_penalty=2.0,
|
| 41 |
early_stopping=True)
|
| 42 |
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
|
|
|
|
| 94 |
print("Starting Fine-tuning...")
|
| 95 |
if model_name == "mT5":
|
| 96 |
max_input = 512
|
| 97 |
+
max_output = 60
|
| 98 |
else:
|
| 99 |
max_input = 1024
|
| 100 |
+
max_output = 60
|
| 101 |
|
| 102 |
train_dataset = train
|
| 103 |
eval_dataset = val
|