Spaces:
Build error
Build error
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| import lightning as pl | |
| from lightning.pytorch.callbacks import ModelCheckpoint | |
| from lightning.pytorch.loggers import TensorBoardLogger | |
| from torch.utils.data import Dataset, DataLoader | |
| import textwrap | |
| from transformers import ( | |
| AdamW, | |
| T5ForConditionalGeneration, | |
| T5TokenizerFast as T5Tokenizer | |
| ) | |
| from tqdm.auto import tqdm | |
class NewsSummaryModel(pl.LightningModule):
    """LightningModule that fine-tunes a T5 conditional-generation model.

    The train/validation/test steps read four entries from each batch:
    ``"text_input_ids"``, ``"text_attention_mask"``, ``"labels"`` and
    ``"labels_attention_mask"``. The producer of those batches is not
    part of this class.
    # NOTE(review): the batch entries are presumably padded tensors coming
    # from a T5 tokenizer -- confirm against the Dataset/DataModule.
    """

    def __init__(self, model_name="t5-base", learning_rate=0.0001):
        """Load the pretrained T5 weights.

        Args:
            model_name: Checkpoint identifier handed to
                ``T5ForConditionalGeneration.from_pretrained``. Defaults to
                ``"t5-base"``, the value that used to be hard-coded.
            learning_rate: Learning rate used by ``configure_optimizers``.
                Defaults to ``0.0001``, the value that used to be hard-coded.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_name, return_dict=True
        )

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the wrapped T5 model and return ``(loss, logits)``.

        Args:
            input_ids: Encoder input ids, passed through to the model.
            attention_mask: Encoder attention mask.
            decoder_attention_mask: Attention mask for the decoder side.
            labels: Optional target ids, forwarded to the model as ``labels``.

        Returns:
            A tuple ``(output.loss, output.logits)`` taken from the model
            output.
        """
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def _shared_step(self, batch, metric_name):
        """Compute the loss for ``batch`` and log it under ``metric_name``."""
        # The logits are not needed by any of the step hooks.
        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch, logged as ``train_loss``."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch, logged as ``val_loss``."""
        return self._shared_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        """Return the test loss for one batch, logged as ``test_loss``."""
        return self._shared_step(batch, "test_loss")

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        # ``transformers.AdamW`` is deprecated in favour of the PyTorch
        # implementation. ``eps`` and ``weight_decay`` are pinned to the
        # defaults of the old transformers class (1e-6 and 0.0) so the
        # optimisation hyper-parameters do not silently change; PyTorch's own
        # defaults are 1e-8 and 1e-2.
        # NOTE(review): the file-level ``from transformers import AdamW`` is no
        # longer used by this class and may fail to import on transformers
        # releases that no longer ship it -- consider removing that import.
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            eps=1e-6,
            weight_decay=0.0,
        )