import argparse
from collections import OrderedDict

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

import flwr as fl
from datasets import load_dataset
from huggingface_hub import login
from rouge import Rouge
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

CHECKPOINT = "t5-small"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Authenticate with the Hugging Face Hub. This is optional here, since both
# "t5-small" and the BillSum dataset are public; use
# huggingface_hub.notebook_login() instead when running inside a notebook.
login()


def load_data(node_id):
    """Load the BillSum dataset and build training and eval dataloaders."""
    dataset = load_dataset("lighteval/legal_summarization", "BillSum")
    full_train_dataset = dataset["train"]
    full_eval_dataset = dataset["test"]
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    # Carve out two partitions of the training split: a small 1% slice so a
    # round finishes quickly, and the second half as an alternative partition.
    # Note: node_id is accepted for parity with the CLI flag but is not yet
    # used to pick a partition.
    train_dataset_size = len(full_train_dataset)
    train_dataset_1 = full_train_dataset.select(range(0, train_dataset_size // 100))
    train_dataset_2 = full_train_dataset.select(
        range(train_dataset_size // 2, train_dataset_size)
    )
    eval_dataset = full_eval_dataset.select(range(0, 100))

    # Choose one partition as this client's training data
    train_dataset = train_dataset_1

    # Tokenization happens in collate_fn, so the raw "article"/"summary"
    # columns are handed to the dataloaders unchanged. (The previous
    # tokenizer.prepare_seq2seq_batch map was redundant -- its output was
    # never used -- and that method is deprecated in transformers.)
    trainloader = DataLoader(
        train_dataset, batch_size=4, collate_fn=lambda data: collate_fn(data, tokenizer)
    )
    evalloader = DataLoader(
        eval_dataset, batch_size=4, collate_fn=lambda data: collate_fn(data, tokenizer)
    )
    return trainloader, evalloader, eval_dataset


def collate_fn(data, tokenizer):
    """Collate a batch of article/summary pairs into model-ready tensors."""
    articles = [item["article"] for item in data]
    summaries = [item["summary"] for item in data]

    # Tokenize articles as inputs and summaries as targets; padding is applied
    # per batch so every tensor in the batch has the same length.
    inputs = tokenizer(articles, truncation=True, padding=True, return_tensors="pt")
    labels = tokenizer(
        text_target=summaries, truncation=True, padding=True, return_tensors="pt"
    )["input_ids"]

    # Mask padding positions in the labels so the loss ignores them
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }


def train(net, trainloader, epochs):
    """Train the model locally and return its updated state dict."""
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    total_batches = len(trainloader)
    print("Training started...")
    for epoch in range(epochs):
        for i, batch in enumerate(trainloader, start=1):
            # Move all tensors in the batch to the training device
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Print progress within the current epoch
            print(
                f"\rEpoch {epoch + 1}/{epochs} - Batch {i}/{total_batches} "
                f"- Loss: {loss.item():.4f}",
                end="",
                flush=True,
            )
    print("\nTraining finished.")
    return net.state_dict()
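# For reference, collate_fn can be sanity-checked on its own, outside the
# federated setup. A minimal sketch (the two-item batch below is made-up
# example data, not drawn from BillSum):
#
#     tok = AutoTokenizer.from_pretrained("t5-small")
#     demo_batch = [
#         {"article": "The bill amends section 4 ...", "summary": "Amends section 4."},
#         {"article": "An act to establish ...", "summary": "Establishes a fund."},
#     ]
#     out = collate_fn(demo_batch, tok)
#     print(out["input_ids"].shape, out["attention_mask"].shape, out["labels"].shape)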
def calculate_rouge(net, eval_dataset, tokenizer):
    """Generate summaries for the eval set and return average ROUGE F1 scores."""
    rouge = Rouge()
    references = [example["summary"] for example in eval_dataset]

    net.eval()
    generated_summaries = []
    with torch.no_grad():
        for example in eval_dataset:
            input_ids = tokenizer(
                example["article"], truncation=True, return_tensors="pt"
            )["input_ids"]
            outputs = net.generate(input_ids.to(DEVICE))
            generated_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
            generated_summaries.append(generated_summary)

    # avg=True averages the scores over the whole eval set (indexing scores[0]
    # would only report the first example)
    scores = rouge.get_scores(generated_summaries, references, avg=True)
    rouge_1 = scores["rouge-1"]["f"]
    rouge_2 = scores["rouge-2"]["f"]
    rouge_l = scores["rouge-l"]["f"]
    return rouge_1, rouge_2, rouge_l


def main(node_id):
    net = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT).to(DEVICE)
    trainloader, _, eval_dataset = load_data(node_id)

    # Flower client
    class PlaceholderClient(fl.client.NumPyClient):
        def get_parameters(self, config):
            return [val.cpu().numpy() for _, val in net.state_dict().items()]

        def set_parameters(self, parameters):
            params_dict = zip(net.state_dict().keys(), parameters)
            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            train(net, trainloader, epochs=1)
            return self.get_parameters(config={}), len(trainloader), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
            rouge_1, rouge_2, rouge_l = calculate_rouge(net, eval_dataset, tokenizer)
            print(f"ROUGE-1 Score: {rouge_1:.4f}")
            print(f"ROUGE-2 Score: {rouge_2:.4f}")
            print(f"ROUGE-L Score: {rouge_l:.4f}")
            # Flower expects (loss, num_examples, metrics); no eval loss is
            # computed here, so 0.0 is returned as a placeholder
            return 0.0, len(eval_dataset), {
                "rouge-1": float(rouge_1),
                "rouge-2": float(rouge_2),
                "rouge-l": float(rouge_l),
            }

    # Start client
    fl.client.start_client(
        server_address="127.0.0.1:8089",
        client=PlaceholderClient().to_client(),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        choices=list(range(3)),
        required=True,
        type=int,
        help="ID of this client node (0, 1, or 2), used to select its data partition.",
    )
    node_id = parser.parse_args().node_id
    main(node_id)
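# For completeness: this client expects a Flower server listening on
# 127.0.0.1:8089. A minimal server sketch (assuming Flower's built-in FedAvg
# strategy; run it in a separate terminal before starting the clients):
#
#     import flwr as fl
#
#     fl.server.start_server(
#         server_address="127.0.0.1:8089",
#         config=fl.server.ServerConfig(num_rounds=3),
#         strategy=fl.server.strategy.FedAvg(),
#     )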