ash12321 commited on
Commit
9f99a53
·
verified ·
1 Parent(s): 9ae219e

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -100
train.py DELETED
@@ -1,100 +0,0 @@
1
- # train.py
2
- import os
3
- from torchvision import datasets, transforms, models
4
- import torch
5
- from torch import nn, optim
6
- from torch.utils.data import DataLoader
7
-
8
- # -----------------------------
9
- # CONFIG
10
- # -----------------------------
11
- DATA_DIR = "datasets"
12
- BATCH_SIZE = 16
13
- NUM_EPOCHS = 5
14
- LEARNING_RATE = 1e-4
15
- MODEL_SAVE_PATH = "models/deepfake_model.pth"
16
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
- # -----------------------------
19
- # DATA TRANSFORMS
20
- # -----------------------------
21
- train_transforms = transforms.Compose([
22
- transforms.Resize((224, 224)),
23
- transforms.RandomHorizontalFlip(),
24
- transforms.ToTensor(),
25
- ])
26
-
27
- val_transforms = transforms.Compose([
28
- transforms.Resize((224, 224)),
29
- transforms.ToTensor(),
30
- ])
31
-
32
- # -----------------------------
33
- # LOAD DATA
34
- # -----------------------------
35
- train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_transforms)
36
- val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=val_transforms)
37
-
38
- train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
39
- val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
40
-
41
- # -----------------------------
42
- # MODEL SETUP
43
- # -----------------------------
44
- # Use pretrained ResNet18
45
- model = models.resnet18(pretrained=True)
46
-
47
- # Replace the final layer with 2 classes (real/fake)
48
- num_features = model.fc.in_features
49
- model.fc = nn.Linear(num_features, 2)
50
-
51
- model = model.to(DEVICE)
52
-
53
- # Loss and optimizer
54
- criterion = nn.CrossEntropyLoss()
55
- optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
56
-
57
- # -----------------------------
58
- # TRAIN LOOP
59
- # -----------------------------
60
- for epoch in range(NUM_EPOCHS):
61
- model.train()
62
- running_loss = 0.0
63
- correct = 0
64
- total = 0
65
-
66
- for images, labels in train_loader:
67
- images, labels = images.to(DEVICE), labels.to(DEVICE)
68
- optimizer.zero_grad()
69
- outputs = model(images)
70
- loss = criterion(outputs, labels)
71
- loss.backward()
72
- optimizer.step()
73
-
74
- running_loss += loss.item()
75
- _, predicted = torch.max(outputs, 1)
76
- total += labels.size(0)
77
- correct += (predicted == labels).sum().item()
78
-
79
- train_acc = correct / total
80
- print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {running_loss:.4f} - Train Accuracy: {train_acc:.4f}")
81
-
82
- # Validation
83
- model.eval()
84
- val_correct = 0
85
- val_total = 0
86
- with torch.no_grad():
87
- for images, labels in val_loader:
88
- images, labels = images.to(DEVICE), labels.to(DEVICE)
89
- outputs = model(images)
90
- _, predicted = torch.max(outputs, 1)
91
- val_total += labels.size(0)
92
- val_correct += (predicted == labels).sum().item()
93
- val_acc = val_correct / val_total
94
- print(f"Validation Accuracy: {val_acc:.4f}")
95
-
96
- # -----------------------------
97
- # SAVE MODEL
98
- # -----------------------------
99
- torch.save(model.state_dict(), MODEL_SAVE_PATH)
100
- print(f"Model saved to {MODEL_SAVE_PATH}")