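"""Flower client for federated fine-tuning of t5-small on the BillSum subset
of lighteval/legal_summarization, evaluated with ROUGE scores."""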
import argparse
import warnings
from collections import OrderedDict
from rouge import Rouge
import torch
from torch.optim import AdamW
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import flwr as fl
from huggingface_hub import notebook_login
notebook_login()
def load_data(node_id):
    """Load dataset (training and eval)."""
    # NOTE: node_id is currently unused; every client trains on the same slice.
    dataset = load_dataset("lighteval/legal_summarization", "BillSum")
    full_train_dataset = dataset["train"]
    full_eval_dataset = dataset["test"]
    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    # Carve out two slices of the training split: a small slice (the first 1%)
    # for fast local training, and the second half (currently unused).
    train_dataset_size = len(full_train_dataset)
    train_dataset_1 = full_train_dataset.select(range(0, train_dataset_size // 100))
    train_dataset_2 = full_train_dataset.select(range(train_dataset_size // 2, train_dataset_size))
    eval_dataset = full_eval_dataset.select(range(0, 100))

    # Use the small slice as this client's training data.
    train_dataset = train_dataset_1

    # Tokenization happens in collate_fn, so the raw "article" and "summary"
    # columns are passed through to the DataLoaders unchanged.
    trainloader = DataLoader(
        train_dataset, batch_size=4, collate_fn=lambda data: collate_fn(data, tokenizer)
    )
    evalloader = DataLoader(
        eval_dataset, batch_size=4, collate_fn=lambda data: collate_fn(data, tokenizer)
    )
    return trainloader, evalloader, eval_dataset
def collate_fn(data, tokenizer):
    """Collate function to convert a batch of raw examples into padded tensors."""
    articles = [item["article"] for item in data]
    summaries = [item["summary"] for item in data]
    # Tokenize articles and summaries separately; padding across the batch
    # keeps all tensors the same length, and truncation caps sequences at the
    # model's maximum input length.
    inputs = tokenizer(articles, truncation=True, padding=True, return_tensors="pt")
    # text_target tokenizes the summaries as decoder targets.
    labels = tokenizer(
        text_target=summaries, truncation=True, padding=True, return_tensors="pt"
    )["input_ids"]
    # Replace label padding with -100 so it is ignored by the loss.
    labels[labels == tokenizer.pad_token_id] = -100
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }
def train(net, trainloader, epochs):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    device = next(net.parameters()).device
    net.train()
    total_batches = len(trainloader)
    print("Training started...")
    for epoch in range(epochs):
        for i, batch in enumerate(trainloader, start=1):
            # Move all tensors to the model's device.
            batch = {k: v.to(device) for k, v in batch.items()}
            # The batch already contains "labels", so the model returns a loss.
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Print progress within the current epoch.
            print(
                f"\rEpoch {epoch + 1}/{epochs} - Batch {i}/{total_batches} "
                f"- Loss: {loss.item():.4f}",
                end="", flush=True,
            )
    print("\nTraining finished.")
    return net.state_dict()
def calculate_rouge(net, eval_dataset, tokenizer):
    rouge = Rouge()
    device = next(net.parameters()).device
    references = [example["summary"] for example in eval_dataset]
    generated_summaries = []
    net.eval()
    for example in eval_dataset:
        input_ids = tokenizer(example["article"], truncation=True, return_tensors="pt")["input_ids"]
        with torch.no_grad():
            # max_new_tokens=128 is an arbitrary cap; adjust as needed.
            outputs = net.generate(input_ids.to(device), max_new_tokens=128)
        generated_summaries.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    # avg=True averages the scores over the whole evaluation set instead of
    # reporting only the first example's scores.
    scores = rouge.get_scores(generated_summaries, references, avg=True)
    rouge_1 = scores["rouge-1"]["f"]
    rouge_2 = scores["rouge-2"]["f"]
    rouge_l = scores["rouge-l"]["f"]
    return rouge_1, rouge_2, rouge_l
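# With avg=True, rouge.get_scores returns a dict of averaged scores, e.g.
#   {"rouge-1": {"r": ..., "p": ..., "f": ...}, "rouge-2": {...}, "rouge-l": {...}}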
def main(node_id):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
    trainloader, _, eval_dataset = load_data(node_id)

    # Flower client
    class SummarizationClient(fl.client.NumPyClient):
        def get_parameters(self, config):
            return [val.cpu().numpy() for _, val in net.state_dict().items()]

        def set_parameters(self, parameters):
            params_dict = zip(net.state_dict().keys(), parameters)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            train(net, trainloader, epochs=1)
            # Report the number of training examples, not the number of batches.
            return self.get_parameters(config={}), len(trainloader.dataset), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            tokenizer = AutoTokenizer.from_pretrained("t5-small")
            rouge_1, rouge_2, rouge_l = calculate_rouge(net, eval_dataset, tokenizer)
            print(f"ROUGE-1 Score: {rouge_1:.4f}")
            print(f"ROUGE-2 Score: {rouge_2:.4f}")
            print(f"ROUGE-L Score: {rouge_l:.4f}")
            # No evaluation loss is computed here, so return 0.0 as a
            # placeholder alongside the example count and the ROUGE metrics.
            return 0.0, len(eval_dataset), {
                "rouge-1": float(rouge_1),
                "rouge-2": float(rouge_2),
                "rouge-l": float(rouge_l),
            }

    # Start client
    fl.client.start_client(
        server_address="127.0.0.1:8089", client=SummarizationClient().to_client()
    )
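# This client expects a Flower server listening on 127.0.0.1:8089. A minimal
# server sketch (num_rounds=3 is an arbitrary choice here), e.g. in a
# separate script:
#
#   import flwr as fl
#   fl.server.start_server(
#       server_address="127.0.0.1:8089",
#       config=fl.server.ServerConfig(num_rounds=3),
#   )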
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        choices=list(range(3)),
        required=True,
        type=int,
        help="ID of this client (0, 1, or 2), selecting which partition of "
        "the dataset the client trains on.",
    )
    node_id = parser.parse_args().node_id
    main(node_id)
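# Usage (assuming this file is saved as, e.g., client.py):
#   python client.py --node-id 0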