import os
import torch
from datasets import load_dataset, ClassLabel, Image
from transformers import (
ViTImageProcessor,
ViTForImageClassification,
TrainingArguments,
Trainer,
DefaultDataCollator,
)
import evaluate
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomRotation,
RandomResizedCrop,
RandomHorizontalFlip,
RandomAdjustSharpness,
Resize,
ToTensor,
)
import numpy as np
# --- Configuration ---
# Base checkpoint to fine-tune (224x224 patch-16 ViT).
MODEL_NAME = "google/vit-base-patch16-224"
# Root folder holding train/ and test/ subdirectories of class folders.
DATASET_DIR = "./dataset"
# Where the fine-tuned model and processor are saved.
OUTPUT_DIR = "./model"
# Per-device batch size for both training and evaluation.
BATCH_SIZE = 16
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
def main():
    """Fine-tune a ViT image classifier on an imagefolder-style dataset.

    Expects DATASET_DIR to contain a 'train' subfolder (and optionally
    'test') with one subdirectory per class label, e.g.
    ./dataset/train/REAL, ./dataset/train/FAKE. Saves the fine-tuned
    model and its processor to OUTPUT_DIR.
    """
    # 1. Load Dataset
    print("Loading dataset...")
    # The 'imagefolder' builder discovers splits and labels from the
    # directory layout itself, so we only validate that a train split
    # exists (dataset["train"] is accessed unconditionally below).
    if not os.path.isdir(os.path.join(DATASET_DIR, "train")):
        print(f"Error: No data found in {DATASET_DIR}. Please organize data in 'train' and 'test' folders.")
        print("Expected structure: ./dataset/train/REAL, ./dataset/train/FAKE, etc.")
        return
    dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)
    has_test = "test" in dataset

    # 2. Labels — build the id<->label maps the model config expects.
    labels = dataset["train"].features["label"].names
    id2label = {str(i): c for i, c in enumerate(labels)}
    label2id = {c: str(i) for i, c in enumerate(labels)}
    print(f"Labels found: {labels}")

    # 3. Preprocessing — normalize with the processor's own statistics so
    # inputs match what the pretrained backbone saw.
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    size = processor.size["height"]
    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
    _train_transforms = Compose([
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        RandomAdjustSharpness(2),
        ToTensor(),
        normalize,
    ])
    _val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])

    def train_transforms(examples):
        # Convert to RGB first: imagefolder may yield grayscale/RGBA images.
        examples["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in examples["image"]]
        return examples

    def val_transforms(examples):
        examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
        return examples

    # Apply transforms lazily (on access) rather than materializing tensors.
    print("Applying transforms...")
    dataset["train"].set_transform(train_transforms)
    if has_test:
        dataset["test"].set_transform(val_transforms)

    # 4. Model — replace the 1000-class ImageNet head with our label set;
    # ignore_mismatched_sizes lets the new head be randomly initialized.
    print(f"Loading model {MODEL_NAME}...")
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # 5. Metrics
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # 6. Training Arguments.
    # Evaluation/best-model tracking only make sense when a test split
    # exists; otherwise Trainer would fail with eval_dataset=None.
    # NOTE(review): 'evaluation_strategy' was renamed 'eval_strategy' in
    # newer transformers releases — confirm against the pinned version.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        remove_unused_columns=False,  # keep 'pixel_values' added by set_transform
        evaluation_strategy="epoch" if has_test else "no",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=has_test,
        metric_for_best_model="accuracy",
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if has_test else None,
        tokenizer=processor,  # saved alongside the model checkpoints
        data_collator=DefaultDataCollator(),
        compute_metrics=compute_metrics,
    )

    # 7. Train
    print("Starting training...")
    trainer.train()

    # 8. Save the final model and processor together so the OUTPUT_DIR is
    # directly loadable with from_pretrained().
    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Done!")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()