Spaces:
Runtime error
Runtime error
Fixes
Browse files- src/models/evaluate_model.py +1 -1
- src/models/model.py +18 -6
- src/models/train_model.py +6 -4
src/models/evaluate_model.py
CHANGED
|
@@ -10,7 +10,7 @@ def evaluate_model():
|
|
| 10 |
test_df = pd.load_csv('../../data/processed/test.csv')
|
| 11 |
model = Summarization()
|
| 12 |
model.load_model()
|
| 13 |
-
results = model.evaluate(test_df=test_df)
|
| 14 |
with dagshub.dagshub_logger() as logger:
|
| 15 |
logger.log_metrics(results)
|
| 16 |
return results
|
|
|
|
| 10 |
test_df = pd.load_csv('../../data/processed/test.csv')
|
| 11 |
model = Summarization()
|
| 12 |
model.load_model()
|
| 13 |
+
results = model.evaluate(test_df=test_df,metrics="rouge")
|
| 14 |
with dagshub.dagshub_logger() as logger:
|
| 15 |
logger.log_metrics(results)
|
| 16 |
return results
|
src/models/model.py
CHANGED
|
@@ -6,7 +6,7 @@ from dagshub.pytorch_lightning import DAGsHubLogger
|
|
| 6 |
from transformers import (
|
| 7 |
AdamW,
|
| 8 |
T5ForConditionalGeneration,
|
| 9 |
-
T5TokenizerFast as T5Tokenizer,
|
| 10 |
)
|
| 11 |
from torch.utils.data import Dataset, DataLoader
|
| 12 |
import pytorch_lightning as pl
|
|
@@ -250,16 +250,28 @@ class Summarization:
|
|
| 250 |
""" initiates Summarization class """
|
| 251 |
pass
|
| 252 |
|
| 253 |
-
def from_pretrained(self, model_name="t5-base") -> None:
|
| 254 |
"""
|
| 255 |
loads T5/MT5 Model model for training/finetuning
|
| 256 |
Args:
|
| 257 |
model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
|
|
|
|
| 258 |
"""
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
def train(
|
| 265 |
self,
|
|
|
|
| 6 |
from transformers import (
|
| 7 |
AdamW,
|
| 8 |
T5ForConditionalGeneration,
|
| 9 |
+
T5TokenizerFast as T5Tokenizer, MT5Tokenizer, MT5ForConditionalGeneration,ByT5Tokenizer,
|
| 10 |
)
|
| 11 |
from torch.utils.data import Dataset, DataLoader
|
| 12 |
import pytorch_lightning as pl
|
|
|
|
| 250 |
""" initiates Summarization class """
|
| 251 |
pass
|
| 252 |
|
| 253 |
+
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
    """
    Load a pretrained tokenizer and model for training/finetuning.

    Args:
        model_type (str, optional): model family to load; one of "t5", "mt5" or "byt5". Defaults to "t5".
        model_name (str, optional): exact model architecture name, e.g. "t5-base" or "t5-large". Defaults to "t5-base".

    Raises:
        ValueError: if ``model_type`` is not one of the supported families.
    """
    supported_types = ("t5", "mt5", "byt5")
    # Fail fast: without this check an unknown model_type would leave
    # self.tokenizer / self.model unset and only break later on first use
    # (e.g. when a checkpoint name is passed as the first positional argument).
    if model_type not in supported_types:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {supported_types}"
        )

    if model_type == "t5":
        tokenizer_cls, model_cls = T5Tokenizer, T5ForConditionalGeneration
    elif model_type == "mt5":
        tokenizer_cls, model_cls = MT5Tokenizer, MT5ForConditionalGeneration
    else:
        # "byt5" pairs its own tokenizer with the plain T5 model class.
        tokenizer_cls, model_cls = ByT5Tokenizer, T5ForConditionalGeneration

    self.tokenizer = tokenizer_cls.from_pretrained(model_name)
    self.model = model_cls.from_pretrained(model_name, return_dict=True)
|
| 275 |
|
| 276 |
def train(
|
| 277 |
self,
|
src/models/train_model.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
| 1 |
from src.models.model import Summarization
|
| 2 |
import pandas as pd
|
| 3 |
|
|
|
|
| 4 |
def train_model():
|
| 5 |
"""
|
| 6 |
Train the model
|
| 7 |
"""
|
| 8 |
# Load the data
|
| 9 |
-
train_df = pd.
|
| 10 |
-
eval_df = pd.
|
| 11 |
|
| 12 |
model = Summarization()
|
| 13 |
-
model.from_pretrained('t5-base')
|
| 14 |
model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
|
| 15 |
model.save_model()
|
| 16 |
|
|
|
|
| 17 |
if __name__ == '__main__':
|
| 18 |
-
train_model()
|
|
|
|
| 1 |
from src.models.model import Summarization
|
| 2 |
import pandas as pd
|
| 3 |
|
| 4 |
+
|
| 5 |
def train_model():
    """
    Fine-tune the summarization model on the processed data and save it.

    Reads the processed train/validation CSV files, loads the pretrained
    't5-base' checkpoint, runs training and then calls ``save_model``.
    """
    # Load the data (paths are resolved against the current working directory)
    train_frame = pd.read_csv('../../data/processed/train.csv')
    eval_frame = pd.read_csv('../../data/processed/validation.csv')

    summarizer = Summarization()
    summarizer.from_pretrained(model_type='t5', model_name='t5-base')
    summarizer.train(
        train_df=train_frame,
        eval_df=eval_frame,
        batch_size=4,
        max_epochs=3,
        use_gpu=True,
    )
    summarizer.save_model()


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train_model()
|