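"""Flower federated-learning client that fine-tunes t5-small on the BillSum
legal-summarization dataset and reports ROUGE scores during evaluation."""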
import argparse
from collections import OrderedDict

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

import flwr as fl
from datasets import load_dataset
from huggingface_hub import login
from rouge import Rouge
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Log in to the Hugging Face Hub. `login()` prompts for a token in the
# terminal; `notebook_login()` only works inside notebooks.
login()
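
# Run on the GPU when one is available; fall back to the CPU so the script
# also works on machines without CUDA (the original hard-coded "cuda").
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
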
def load_data(node_id):
    """Load the BillSum dataset and build training and evaluation loaders."""
    dataset = load_dataset("lighteval/legal_summarization", "BillSum")
    full_train_dataset = dataset["train"]
    full_eval_dataset = dataset["test"]

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    # Give each node a small, disjoint 1% slice of the training split so the
    # partition actually depends on the node ID.
    partition_size = len(full_train_dataset) // 100
    train_dataset = full_train_dataset.select(
        range(node_id * partition_size, (node_id + 1) * partition_size)
    )

    # Keep evaluation fast by scoring only the first 100 test examples.
    eval_dataset = full_eval_dataset.select(range(100))

    # Tokenization happens in the collate function, so the raw text columns
    # are all the dataloaders need.
    trainloader = DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=lambda data: collate_fn(data, tokenizer),
    )
    evalloader = DataLoader(
        eval_dataset,
        batch_size=4,
        collate_fn=lambda data: collate_fn(data, tokenizer),
    )

    return trainloader, evalloader, eval_dataset

def collate_fn(data, tokenizer):
    """Tokenize a batch of article/summary pairs into padded tensors."""
    articles = [item["article"] for item in data]
    summaries = [item["summary"] for item in data]

    # Tokenize the articles as encoder inputs and the summaries as targets.
    # Passing both texts positionally would concatenate them into a single
    # sequence pair, so the summaries go through `text_target` instead.
    inputs = tokenizer(
        articles,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = tokenizer(
        text_target=summaries,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )["input_ids"]

    # Replace padding token IDs in the labels with -100 so the loss function
    # ignores them.
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }

def train(net, trainloader, epochs):
    """Train the model for the given number of epochs and return its weights."""
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    total_batches = len(trainloader)
    print("Training started...")
    for epoch in range(epochs):
        for i, batch in enumerate(trainloader, start=1):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            # The batch includes labels, so the model returns the loss directly.
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(
                f"\rEpoch {epoch + 1}/{epochs} - Batch {i}/{total_batches} "
                f"- Loss: {loss.item():.4f}",
                end="",
                flush=True,
            )
    print("\nTraining finished.")
    return net.state_dict()

def calculate_rouge(net, eval_dataset, tokenizer):
    """Generate summaries for the eval set and return averaged ROUGE F1 scores."""
    rouge = Rouge()
    references = [example["summary"] for example in eval_dataset]

    net.eval()
    generated_summaries = []
    with torch.no_grad():
        for example in eval_dataset:
            input_ids = tokenizer(
                example["article"],
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )["input_ids"]
            outputs = net.generate(input_ids.to(DEVICE), max_new_tokens=128)
            generated_summaries.append(
                tokenizer.decode(outputs[0], skip_special_tokens=True)
            )

    # Average the scores across the whole eval set instead of reading only
    # the first example's scores.
    scores = rouge.get_scores(generated_summaries, references, avg=True)
    return scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"]

def main(node_id):
    net = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    trainloader, _, eval_dataset = load_data(node_id)

    class SummarizationClient(fl.client.NumPyClient):
        def get_parameters(self, config):
            return [val.cpu().numpy() for _, val in net.state_dict().items()]

        def set_parameters(self, parameters):
            params_dict = zip(net.state_dict().keys(), parameters)
            # torch.tensor preserves each parameter's dtype, unlike
            # torch.Tensor, which always casts to float32.
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            train(net, trainloader, epochs=1)
            return self.get_parameters(config={}), len(trainloader), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            rouge_1, rouge_2, rouge_l = calculate_rouge(net, eval_dataset, tokenizer)
            print(f"ROUGE-1 Score: {rouge_1:.4f}")
            print(f"ROUGE-2 Score: {rouge_2:.4f}")
            print(f"ROUGE-L Score: {rouge_l:.4f}")
            return 0.0, len(eval_dataset), {
                "rouge-1": float(rouge_1),
                "rouge-2": float(rouge_2),
                "rouge-l": float(rouge_l),
            }

    fl.client.start_client(
        server_address="127.0.0.1:8089",
        client=SummarizationClient().to_client(),
    )

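# Example invocation (assumes this file is saved as `client.py` and a Flower
# server is already listening on 127.0.0.1:8089):
#
#   python client.py --node-id 0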
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        choices=list(range(3)),
        required=True,
        type=int,
        help="ID of the client node, used to select one of 3 disjoint "
        "training partitions.",
    )
    node_id = parser.parse_args().node_id
    main(node_id)