Monimoy commited on
Commit
c952b29
·
1 Parent(s): 2597d7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -91
app.py CHANGED
@@ -22,97 +22,8 @@ from torchvision.datasets import CIFAR10
22
  PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
23
  BATCH_SIZE = 64
24
 
25
- class LitCustomResNet(LightningModule):
26
- def __init__(self, data_dir=PATH_DATASETS, learning_rate=0.05):
27
-
28
- super().__init__()
29
-
30
- self.model = CustomResNet()
31
- # Set our init args as class attributes
32
- self.data_dir = data_dir
33
- self.learning_rate = learning_rate
34
-
35
- # Hardcode some dataset specific attributes
36
- self.num_classes = 10
37
- self.train_transform = transforms.Compose([
38
- transforms.RandomHorizontalFlip(),
39
- transforms.RandomCrop(32, padding=4),
40
- transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
41
- transforms.ToTensor(),
42
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
43
- ])
44
- self.test_transform = transforms.Compose([
45
- transforms.ToTensor(),
46
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
47
- ])
48
-
49
- self.accuracy = Accuracy(task='multiclass', num_classes=10)
50
-
51
-
52
- def forward(self, x):
53
- x = self.model(x)
54
- x = x.view(-1, 10)
55
- return F.log_softmax(x, dim=1)
56
-
57
- def training_step(self, batch, batch_idx):
58
- x, y = batch
59
- logits = self(x)
60
- loss = F.nll_loss(logits, y)
61
- self.log("train_loss", loss)
62
- return loss
63
-
64
- def validation_step(self, batch, batch_idx):
65
- x, y = batch
66
- logits = self(x)
67
- loss = F.nll_loss(logits, y)
68
- preds = torch.argmax(logits, dim=1)
69
- self.accuracy(preds, y)
70
-
71
- # Calling self.log will surface up scalars for you in TensorBoard
72
- self.log("val_loss", loss, prog_bar=True)
73
- self.log("val_acc", self.accuracy, prog_bar=True)
74
- return loss
75
-
76
- def test_step(self, batch, batch_idx):
77
- # Here we just reuse the validation_step for testing
78
- return self.validation_step(batch, batch_idx)
79
-
80
- def configure_optimizers(self):
81
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
82
- return optimizer
83
-
84
- ####################
85
- # DATA RELATED HOOKS
86
- ####################
87
-
88
- def prepare_data(self):
89
- # download
90
- CIFAR10(root=self.data_dir, train=True, download=True)
91
- CIFAR10(root=self.data_dir, train=False, download=True)
92
-
93
- def setup(self, stage=None):
94
-
95
- # Assign train/val datasets for use in dataloaders
96
- if stage == "fit" or stage is None:
97
- cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
98
- self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45000, 5000])
99
-
100
- # Assign test dataset for use in dataloader(s)
101
- if stage == "test" or stage is None:
102
- self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)
103
-
104
- def train_dataloader(self):
105
- return DataLoader(self.cifar10_train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
106
-
107
- def val_dataloader(self):
108
- return DataLoader(self.cifar10_val, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
109
-
110
- def test_dataloader(self):
111
- return DataLoader(self.cifar10_test, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
112
-
113
- model = LitCustomResNet.load_from_checkpoint('custom_resnet.ckpt', strict=False)
114
- #lit_model = torch.load(checkpoint)
115
- #model = lit_model.model
116
 
117
  inv_normalize = transforms.Normalize(
118
  mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
 
22
  PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
23
  BATCH_SIZE = 64
24
 
25
+ model = CustomResNet(input_size=32,learning_rate=0.001,num_classes=10,)
26
+ model.load_state_dict(torch.load("custom_resnet_model.pth", map_location=torch.device('cpu')), strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  inv_normalize = transforms.Normalize(
29
  mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],