mrs83's picture
Initial Import
3e25ded
raw
history blame
3.52 kB
"""huggingface_example: A Flower / Hugging Face app."""
from typing import Any
from collections import OrderedDict
import torch
from evaluate import load as load_metric
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
AutoTokenizer,
DataCollatorWithPadding,
AutoModelForSequenceClassification,
)
from datasets.utils.logging import disable_progress_bar
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
# Silence the `datasets` progress bars so client logs stay readable.
disable_progress_bar()

# Module-level cache for the partitioned dataset: `load_data` builds it lazily
# on first use and reuses it on later calls in the same process.
fds: "FederatedDataset | None" = None
def load_data(
    partition_id: int,
    num_partitions: int,
    model_name: str,
    *,
    batch_size: int = 32,
) -> tuple[DataLoader[Any], DataLoader[Any]]:
    """Load one IID partition of IMDB and return ``(trainloader, testloader)``.

    The ``train`` split of ``stanfordnlp/imdb`` is divided into
    ``num_partitions`` IID partitions. Partition ``partition_id`` is split
    80/20 (seed 42) into local train/test sets. Both sets are tokenized with
    ``model_name``'s tokenizer (truncated to 512 tokens) and wrapped in
    DataLoaders that pad each batch with ``DataCollatorWithPadding``.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of IID partitions.
        model_name: Hugging Face model id whose tokenizer is used.
        batch_size: Batch size for both loaders (default 32).

    Returns:
        ``(trainloader, testloader)``. The train loader shuffles; the test
        loader does not.
    """
    global fds
    # Build the FederatedDataset once per process and reuse it. Rebuild it if a
    # different partition count is requested: the cached partitioner was
    # created for the old count and would otherwise be reused silently.
    if fds is None or getattr(load_data, "_num_partitions", None) != num_partitions:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="stanfordnlp/imdb", partitioners={"train": partitioner}
        )
        load_data._num_partitions = num_partitions

    partition = fds.load_partition(partition_id)
    # Local 80/20 split; the fixed seed keeps it identical across calls.
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)

    def tokenize_function(examples):
        # Truncate only; padding is done per batch by the collator below.
        return tokenizer(examples["text"], truncation=True, add_special_tokens=True)

    partition_train_test = partition_train_test.map(tokenize_function, batched=True)
    partition_train_test = partition_train_test.remove_columns("text")
    # `train`/`test` pass each batch as keyword arguments and read
    # `outputs.loss`, so the label column must be named `labels`.
    partition_train_test = partition_train_test.rename_column("label", "labels")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        partition_train_test["train"],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    testloader = DataLoader(
        partition_train_test["test"],
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    return trainloader, testloader
def get_model(model_name, num_labels=2):
    """Load the pretrained ``model_name`` with a sequence-classification head.

    Args:
        model_name: Hugging Face model id or local path.
        num_labels: Number of output classes (default 2, the label count this
            app was written for).
    """
    return AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
def get_params(model):
    """Return every ``model.state_dict()`` tensor as a NumPy array.

    Arrays are copied to CPU and listed in ``state_dict`` order, the same
    order ``set_params`` uses to match arrays back to keys.
    """
    # Only the tensors are needed; the keys are implied by the ordering.
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]
def set_params(model, parameters) -> None:
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
def train(net, trainloader, epochs, device, lr: float = 5e-5) -> None:
    """Fine-tune ``net`` in place with AdamW.

    Each batch dict from ``trainloader`` is moved to ``device`` and passed to
    ``net`` as keyword arguments; ``outputs.loss`` is backpropagated and one
    optimizer step is taken per batch. ``net`` itself is not moved here;
    presumably the caller has already placed it on ``device``.

    Args:
        net: Model whose forward returns an object with a ``loss`` attribute.
        trainloader: Iterable of dicts of tensors.
        epochs: Number of passes over ``trainloader``.
        device: Target device for the batch tensors.
        lr: AdamW learning rate (default 5e-5).
    """
    # A fresh optimizer per call, so no Adam moment state carries over between calls.
    optimizer = AdamW(net.parameters(), lr=lr)
    net.train()
    for epoch in range(epochs):
        print(f"Training Epoch {epoch}")
        for batch in trainloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
def test(net, testloader, device) -> tuple[float, float]:
    """Evaluate ``net`` on ``testloader``.

    Each batch dict is moved to ``device`` and passed to ``net`` as keyword
    arguments; ``outputs.loss`` and ``outputs.logits`` are read, and
    predictions are the argmax of the logits over the last dimension.

    Returns:
        ``(loss, accuracy)``: the mean loss per example and the fraction of
        examples whose prediction equals ``batch["labels"]``.

    Raises:
        ValueError: If ``testloader`` yields no examples.
    """
    net.eval()
    total_loss = 0.0
    num_correct = 0
    num_examples = 0
    with torch.no_grad():
        for batch in testloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            labels = batch["labels"]
            batch_size = labels.size(0)
            # Assumes outputs.loss is the batch *mean* (the default reduction of
            # HF classification heads). Weighting it by batch size makes the
            # final value a true per-example mean. The old code divided a sum of
            # batch means by the example count, which under-reported the loss.
            total_loss += outputs.loss.item() * batch_size
            predictions = torch.argmax(outputs.logits, dim=-1)
            num_correct += (predictions == labels).sum().item()
            num_examples += batch_size
    if num_examples == 0:
        raise ValueError("testloader yielded no examples")
    return total_loss / num_examples, num_correct / num_examples


# The name `test` matches pytest's collection pattern. Opt out so a test
# module that imports this helper does not try to run it as a test.
test.__test__ = False