# ai-detection / train.py
# (Hub page residue from the upload — "maddyrox's picture", "Upload 23 files",
#  commit 6361699 verified — kept here as a comment so the file parses.)
import os
import torch
from datasets import load_dataset, ClassLabel, Image
from transformers import (
ViTImageProcessor,
ViTForImageClassification,
TrainingArguments,
Trainer,
DefaultDataCollator,
)
import evaluate
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomRotation,
RandomResizedCrop,
RandomHorizontalFlip,
RandomAdjustSharpness,
Resize,
ToTensor,
)
import numpy as np
# --- Configuration ---
MODEL_NAME = "google/vit-base-patch16-224"  # base ViT checkpoint to fine-tune
DATASET_DIR = "./dataset"  # imagefolder layout: dataset/{train,test}/<LABEL>/*.jpg
OUTPUT_DIR = "./model"  # destination for the fine-tuned model + image processor
BATCH_SIZE = 16  # per-device batch size for both train and eval
NUM_EPOCHS = 3  # total fine-tuning epochs
LEARNING_RATE = 2e-5  # peak learning rate (warmup_ratio=0.1 ramps up to this)
def main():
    """Fine-tune a ViT image classifier on an ``imagefolder``-style dataset.

    Expects ``DATASET_DIR`` to contain ``train/<label>/`` sub-folders
    (mandatory) and optionally ``test/<label>/`` for evaluation. Saves the
    trained model and the image processor to ``OUTPUT_DIR``.
    """
    # 1. Load dataset --------------------------------------------------------
    print("Loading dataset...")
    # The imagefolder builder infers class labels from sub-directory names.
    # A "train" split is mandatory (we index dataset["train"] below); "test"
    # is optional. The original only required *one* of the two folders and
    # then crashed later on dataset["train"] — validate train up front.
    if not os.path.isdir(os.path.join(DATASET_DIR, "train")):
        print(f"Error: No data found in {DATASET_DIR}. Please organize data in 'train' and 'test' folders.")
        print("Expected structure: ./dataset/train/REAL, ./dataset/train/FAKE, etc.")
        return
    dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)
    has_eval = "test" in dataset

    # 2. Labels --------------------------------------------------------------
    labels = dataset["train"].features["label"].names
    id2label = {str(i): c for i, c in enumerate(labels)}
    label2id = {c: str(i) for i, c in enumerate(labels)}
    print(f"Labels found: {labels}")

    # 3. Preprocessing -------------------------------------------------------
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
    size = processor.size["height"]  # ViT processors expose {"height": H, "width": W}

    # Light augmentation for training; deterministic resize+crop for eval.
    _train_transforms = Compose([
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        RandomAdjustSharpness(2),
        ToTensor(),
        normalize,
    ])
    _val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])

    def train_transforms(examples):
        # convert("RGB") first: imagefolder may yield grayscale/RGBA images.
        examples["pixel_values"] = [_train_transforms(img.convert("RGB")) for img in examples["image"]]
        return examples

    def val_transforms(examples):
        examples["pixel_values"] = [_val_transforms(img.convert("RGB")) for img in examples["image"]]
        return examples

    # set_transform applies lazily per batch — no upfront preprocessing pass.
    print("Applying transforms...")
    dataset["train"].set_transform(train_transforms)
    if has_eval:
        dataset["test"].set_transform(val_transforms)

    # 4. Model ---------------------------------------------------------------
    print(f"Loading model {MODEL_NAME}...")
    # ignore_mismatched_sizes: the checkpoint's 1000-class ImageNet head is
    # discarded and a freshly initialised head sized for our labels is used.
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # 5. Metrics -------------------------------------------------------------
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # eval_pred.predictions are per-class logits; argmax -> predicted ids.
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # 6. Training arguments --------------------------------------------------
    # Bug fix: the original unconditionally set evaluation_strategy="epoch",
    # load_best_model_at_end=True and metric_for_best_model, which makes
    # Trainer raise when there is no eval dataset. Gate all three on the
    # presence of a test split (which the rest of the code treats as optional).
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        remove_unused_columns=False,  # keep "image" column so set_transform can read it
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=has_eval,
        metric_for_best_model="accuracy" if has_eval else None,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if has_eval else None,
        tokenizer=processor,  # saved with checkpoints; NOTE(review): newer transformers renames this to processing_class
        data_collator=DefaultDataCollator(),
        compute_metrics=compute_metrics,
    )

    # 7. Train ---------------------------------------------------------------
    print("Starting training...")
    trainer.train()

    # 8. Save ----------------------------------------------------------------
    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)  # inference needs the processor config too
    print("Done!")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()