File size: 4,599 Bytes
d3e1b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import torch
from datasets import load_dataset, ClassLabel, Image
from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
import evaluate
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomRotation,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    Resize,
    ToTensor,
)
import numpy as np

# --- Configuration ---
MODEL_NAME = "google/vit-base-patch16-224"  # pretrained ViT checkpoint to fine-tune
DATASET_DIR = "./dataset"   # expects imagefolder layout: ./dataset/train/<LABEL>/, ./dataset/test/<LABEL>/
OUTPUT_DIR = "./model"      # where checkpoints and the final model/processor are saved
BATCH_SIZE = 16             # per-device batch size for both train and eval
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5        # typical fine-tuning LR for ViT-base

def main():
    """Fine-tune a ViT image classifier on a local image-folder dataset.

    Expects ``DATASET_DIR/train/<LABEL>/*`` (required) and optionally
    ``DATASET_DIR/test/<LABEL>/*``. Trains ``MODEL_NAME`` with a new
    classification head sized to the discovered labels and saves the model
    and image processor to ``OUTPUT_DIR``.

    Returns:
        None. Prints an error and returns early if no training data exists.
    """
    # 1. Load Dataset
    print("Loading dataset...")
    # The "imagefolder" builder discovers splits and labels from the directory
    # layout on its own, so we only pre-check that a train split exists:
    # without it, dataset["train"] below would raise KeyError.
    if not os.path.isdir(os.path.join(DATASET_DIR, "train")):
        print(f"Error: No data found in {DATASET_DIR}. Please organize data in 'train' and 'test' folders.")
        print("Expected structure: ./dataset/train/REAL, ./dataset/train/FAKE, etc.")
        return

    dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)
    has_eval = "test" in dataset

    # 2. Labels — the imagefolder builder encodes folder names as a ClassLabel.
    labels = dataset["train"].features["label"].names
    id2label = {str(i): c for i, c in enumerate(labels)}
    label2id = {c: str(i) for i, c in enumerate(labels)}
    print(f"Labels found: {labels}")

    # 3. Preprocessing — normalize with the checkpoint's own mean/std and
    # resize to the resolution the checkpoint was pretrained at.
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    image_mean = processor.image_mean
    image_std = processor.image_std
    size = processor.size["height"]  # ViT processors use a square {"height", "width"} size

    normalize = Normalize(mean=image_mean, std=image_std)

    # Light augmentation for training; deterministic resize/crop for eval.
    _train_transforms = Compose([
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        RandomAdjustSharpness(2),
        ToTensor(),
        normalize,
    ])

    _val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])

    def train_transforms(examples):
        # Convert to RGB first: imagefolder may yield grayscale/RGBA images.
        examples["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in examples["image"]]
        return examples

    def val_transforms(examples):
        examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
        return examples

    # Apply transforms lazily (set_transform runs per-batch at access time,
    # so raw images are never materialized as tensors up front).
    print("Applying transforms...")
    dataset["train"].set_transform(train_transforms)
    if has_eval:
        dataset["test"].set_transform(val_transforms)

    # 4. Model — replace the 1000-class ImageNet head with one sized to our
    # labels; ignore_mismatched_sizes allows the head-shape mismatch.
    print(f"Loading model {MODEL_NAME}...")
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # 5. Metrics
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # 6. Training Arguments.
    # Eval-dependent options are enabled only when a test split exists:
    # evaluation_strategy="epoch" / load_best_model_at_end=True with no
    # eval_dataset makes the Trainer raise at construction/run time.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        remove_unused_columns=False,  # keep "image" so set_transform can read it
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=has_eval,
        metric_for_best_model="accuracy",
        push_to_hub=False,
    )

    collator = DefaultDataCollator()

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if has_eval else None,
        tokenizer=processor,  # saved alongside checkpoints for easy reload
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    # 7. Train
    print("Starting training...")
    trainer.train()

    # 8. Save final model + processor together so OUTPUT_DIR is self-contained.
    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Done!")

# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()