import argparse
import warnings
from collections import OrderedDict
from rouge import Rouge
import torch
from torch.optim import AdamW
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import flwr as fl
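
# Flower federated-learning client: each client fine-tunes t5-small on its own
# slice of the BillSum legal-summarization data and reports ROUGE scores back
# to the server during evaluation.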


# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optional: authenticate with the Hugging Face Hub (only needed for gated or
# private assets). notebook_login() is meant for notebooks; in a plain script
# use huggingface_hub.login() or `huggingface-cli login` instead.
# from huggingface_hub import login
# login()


def load_data(node_id):
    """Load this client's training partition and a shared eval split."""
    dataset = load_dataset("lighteval/legal_summarization", "BillSum")
    full_train_dataset = dataset["train"]
    full_eval_dataset = dataset["test"]

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    # Give each client a small, non-overlapping 1% slice of the training set,
    # selected by its node id.
    partition_size = len(full_train_dataset) // 100
    start = node_id * partition_size
    train_dataset = full_train_dataset.select(range(start, start + partition_size))

    # All clients evaluate on the same 100-example subset of the test split.
    eval_dataset = full_eval_dataset.select(range(0, 100))

    # Tokenization happens in collate_fn, so the raw text columns are kept as-is.
    trainloader = DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=lambda data: collate_fn(data, tokenizer),
    )
    evalloader = DataLoader(
        eval_dataset,
        batch_size=4,
        collate_fn=lambda data: collate_fn(data, tokenizer),
    )

    return trainloader, evalloader, eval_dataset


def collate_fn(data, tokenizer):
    """Tokenize a batch of raw article/summary pairs into padded tensors."""
    articles = [item["article"] for item in data]
    summaries = [item["summary"] for item in data]

    # Tokenize articles as inputs and summaries as targets. Passing the
    # summaries via text_target makes the tokenizer emit a "labels" field.
    batch = tokenizer(
        articles,
        text_target=summaries,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    return {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "labels": batch["labels"],
    }


def train(net, trainloader, epochs):
    """Train the model locally for the given number of epochs."""
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    total_batches = len(trainloader)
    print("Training started...")
    for epoch in range(epochs):
        for i, batch in enumerate(trainloader, start=1):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}  # Move all tensors to the training device
            outputs = net(**batch)  # "labels" is in the batch, so outputs.loss is populated
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Print progress within the current epoch
            print(
                f"\rEpoch {epoch + 1}/{epochs} - Batch {i}/{total_batches} - Loss: {loss.item():.4f}",
                end="", flush=True,
            )
        print()
    print("Training finished.")

    return net.state_dict()


def calculate_rouge(net, eval_dataset, tokenizer):
    """Generate summaries for the eval set and return averaged ROUGE F1 scores."""
    rouge = Rouge()
    references = [example["summary"] for example in eval_dataset]

    net.eval()
    generated_summaries = []
    for example in eval_dataset:
        input_ids = tokenizer(example["article"], truncation=True, return_tensors="pt")["input_ids"]
        outputs = net.generate(input_ids.to(DEVICE))
        generated_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_summaries.append(generated_summary)

    # avg=True averages the scores over the whole eval set instead of
    # reporting only the first example's scores.
    scores = rouge.get_scores(generated_summaries, references, avg=True)
    rouge_1 = scores["rouge-1"]["f"]
    rouge_2 = scores["rouge-2"]["f"]
    rouge_l = scores["rouge-l"]["f"]

    return rouge_1, rouge_2, rouge_l

def main(node_id):
    net = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(DEVICE)

    trainloader, _, eval_dataset = load_data(node_id)

    # Flower client
    class PlaceholderClient(fl.client.NumPyClient):
        def get_parameters(self, config):
            return [val.cpu().numpy() for _, val in net.state_dict().items()]

        def set_parameters(self, parameters):
            params_dict = zip(net.state_dict().keys(), parameters)
            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            train(net, trainloader, epochs=1)
            # Report the number of local training examples so the server can
            # weight this client's update correctly.
            return self.get_parameters(config={}), len(trainloader.dataset), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            tokenizer = AutoTokenizer.from_pretrained("t5-small")
            rouge_1, rouge_2, rouge_l = calculate_rouge(net, eval_dataset, tokenizer)
            print(f"ROUGE-1 Score: {rouge_1:.4f}")
            print(f"ROUGE-2 Score: {rouge_2:.4f}")
            print(f"ROUGE-L Score: {rouge_l:.4f}")
            # Flower expects (loss, num_examples, metrics); no separate loss is
            # computed here, so 0.0 is returned as a placeholder.
            return 0.0, len(eval_dataset), {
                "rouge-1": float(rouge_1),
                "rouge-2": float(rouge_2),
                "rouge-l": float(rouge_l),
            }

    # Start client
    fl.client.start_client(
        server_address="127.0.0.1:8089", client=PlaceholderClient().to_client()
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        choices=list(range(3)),
        required=True,
        type=int,
        help="Partition of the dataset divided into 1,000 iid partitions created "
             "artificially.",
    )
    node_id = parser.parse_args().node_id
    main(node_id)
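

# How this is typically run (a sketch, not from the original script): start a
# Flower server on 127.0.0.1:8089 first, then launch the clients with different
# --node-id values in separate terminals. The server below is an assumed
# minimal companion using FedAvg defaults, and the filename "client.py" is
# hypothetical.
#
#   # server.py (assumed companion script)
#   import flwr as fl
#   fl.server.start_server(
#       server_address="127.0.0.1:8089",
#       config=fl.server.ServerConfig(num_rounds=3),
#   )
#
#   # then, per client:
#   python client.py --node-id 0
#   python client.py --node-id 1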