Ubuntu committed on
Commit
d7454ed
·
1 Parent(s): bc20825
Files changed (3) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +89 -7
  2. app.py +89 -7
  3. requirements.txt +3 -1
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -4,18 +4,100 @@ import cv2
4
  from encoded_video import EncodedVideo, write_video
5
  import torch
6
  import numpy as np
 
 
 
7
 
8
  def video_identity(video,user_name,class_name,trainortest,ready):
9
  if ready=='yes':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- img=cv2.imread('train/alican/book/frame_0.jpg')
12
- img2=cv2.imread('train/alican/wallet/frame_0.jpg')
13
- return img, img2, class_name
14
  else:
15
  capture = cv2.VideoCapture(video)
16
 
17
- user_d=str(trainortest)+'/'+str(user_name)
18
- class_d=str(trainortest)+'/'+str(user_name)+'/'+str(class_name)
19
  if not os.path.exists(user_d):
20
  os.makedirs(user_d)
21
  if not os.path.exists(class_d):
@@ -35,7 +117,7 @@ def video_identity(video,user_name,class_name,trainortest,ready):
35
 
36
  img=cv2.imread(class_d+'/frame_0.jpg')
37
 
38
- return img, img, class_d
39
  demo = gr.Interface(video_identity,
40
  inputs=[gr.Video(source='upload'),
41
  gr.Text(),
@@ -43,7 +125,7 @@ demo = gr.Interface(video_identity,
43
  gr.Text(label='Which set is this? (type train or test)'),
44
  gr.Text(label='Are you ready? (type yes or no)')],
45
  outputs=[gr.Image(),
46
- gr.Image(),
47
  gr.Text()],
48
  cache_examples=True)
49
  demo.launch(debug=True)
 
4
  from encoded_video import EncodedVideo, write_video
5
  import torch
6
  import numpy as np
7
+ from torchvision.datasets import ImageFolder
8
+ from transformers import ViTFeatureExtractor, ViTForImageClassification, AutoFeatureExtractor, ViTMSNForImageClassification
9
+
10
 
11
  def video_identity(video,user_name,class_name,trainortest,ready):
12
  if ready=='yes':
13
+
14
# Build per-user train/test datasets laid out as
#   train/<user>/<class>/frame_*.jpg  and  test/<user>/<class>/frame_*.jpg
# NOTE(review): ImageFolder is imported at the top of the file, but Path
# (pathlib) is never imported in the visible diff — confirm it is in scope.
data_dir = Path('train/'+str(user_name))
train_ds = ImageFolder(data_dir)

test_dir = Path('test/'+str(user_name))
test_ds = ImageFolder(test_dir)

# label <-> id maps in the string form ViTForImageClassification expects.
# FIX: the original iterated `ds.classes`, but no `ds` is ever defined —
# the intended dataset is `train_ds`. The loop variable is also renamed so
# it no longer shadows the enclosing function's `class_name` parameter.
label2id = {}
id2label = {}
for idx, cls in enumerate(train_ds.classes):
    label2id[cls] = str(idx)
    id2label[str(idx)] = cls
27
+
28
class ImageClassificationCollator:
    """Batch collator for DataLoader over an ImageFolder dataset.

    Runs the HF feature extractor over the images of a batch and attaches
    the integer class labels as a ``labels`` tensor, producing a dict that
    can be splatted straight into the model: ``model(**batch)``.
    """

    def __init__(self, feature_extractor):
        # feature_extractor: a callable (e.g. ViTFeatureExtractor) that maps
        # a list of images to an encoding dict of tensors.
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        # batch is a list of (image, label) pairs as yielded by ImageFolder.
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings
36
# Pretrained ViT-Base checkpoint. The same name is hoisted into one
# constant so the preprocessing (feature extractor) and the model can
# never drift out of sync.
checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
collator = ImageClassificationCollator(feature_extractor)
44
class Classifier(pl.LightningModule):
    """LightningModule wrapper that fine-tunes a HF image classifier.

    The wrapped model's ``forward`` is exposed directly, so ``self(**batch)``
    returns a transformers output object carrying ``.loss`` and ``.logits``.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        # Record 'lr' (plus any extra kwargs) so self.hparams.lr resolves.
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        # Delegate forward to the underlying transformers model.
        self.forward = self.model.forward
        # NOTE(review): Accuracy is presumably torchmetrics.Accuracy, but it
        # is never imported in the visible diff — confirm it is in scope.
        self.val_acc = Accuracy(
            task='multiclass' if model.config.num_labels > 2 else 'binary',
            num_classes=model.config.num_labels
        )

    def training_step(self, batch, batch_idx):
        # batch comes from ImageClassificationCollator: pixel_values + labels.
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        # Plain Adam over all parameters; frozen params keep
        # requires_grad=False and simply receive no gradients.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
70
+
71
+
72
+
73
+ train_loader = DataLoader(train_ds, batch_size=8, collate_fn=collator, num_workers=8, shuffle=True)
74
+ test_loader = DataLoader(test_ds, batch_size=8, collate_fn=collator, num_workers=2)
75
+
76
+
77
+ for name, param in model.named_parameters():
78
+ param.requires_grad = False
79
+ if name.startswith("classifier"): # choose whatever you like here
80
+ param.requires_grad = True
81
+
82
+ pl.seed_everything(42)
83
+ classifier = Classifier(model, lr=2e-5)
84
+ trainer = pl.Trainer(accelerator='cpu', devices=1, precision=16, max_epochs=3)
85
+
86
+ trainer.fit(classifier, train_loader, test_loader)
87
+
88
+ for batch_idx, data in enumerate(test_loader):
89
+ outputs = model(**data)
90
+ img=data['pixel_values'][0][0]
91
+ preds=str(outputs.logits.softmax(1).argmax(1))
92
+ labels=str(data['labels'])
93
 
94
+ return img, preds, labels
95
+
 
96
  else:
97
  capture = cv2.VideoCapture(video)
98
 
99
+ user_d=str(user_name)+'/'+str(trainortest)
100
+ class_d=str(user_name)+'/'+str(trainortest)+'/'+str(class_name)
101
  if not os.path.exists(user_d):
102
  os.makedirs(user_d)
103
  if not os.path.exists(class_d):
 
117
 
118
  img=cv2.imread(class_d+'/frame_0.jpg')
119
 
120
+ return img, trainortest, class_d
121
  demo = gr.Interface(video_identity,
122
  inputs=[gr.Video(source='upload'),
123
  gr.Text(),
 
125
  gr.Text(label='Which set is this? (type train or test)'),
126
  gr.Text(label='Are you ready? (type yes or no)')],
127
  outputs=[gr.Image(),
128
+ gr.Text(),
129
  gr.Text()],
130
  cache_examples=True)
131
  demo.launch(debug=True)
app.py CHANGED
@@ -4,18 +4,100 @@ import cv2
4
  from encoded_video import EncodedVideo, write_video
5
  import torch
6
  import numpy as np
 
 
 
7
 
8
  def video_identity(video,user_name,class_name,trainortest,ready):
9
  if ready=='yes':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- img=cv2.imread('train/alican/book/frame_0.jpg')
12
- img2=cv2.imread('train/alican/wallet/frame_0.jpg')
13
- return img, img2, class_name
14
  else:
15
  capture = cv2.VideoCapture(video)
16
 
17
- user_d=str(trainortest)+'/'+str(user_name)
18
- class_d=str(trainortest)+'/'+str(user_name)+'/'+str(class_name)
19
  if not os.path.exists(user_d):
20
  os.makedirs(user_d)
21
  if not os.path.exists(class_d):
@@ -35,7 +117,7 @@ def video_identity(video,user_name,class_name,trainortest,ready):
35
 
36
  img=cv2.imread(class_d+'/frame_0.jpg')
37
 
38
- return img, img, class_d
39
  demo = gr.Interface(video_identity,
40
  inputs=[gr.Video(source='upload'),
41
  gr.Text(),
@@ -43,7 +125,7 @@ demo = gr.Interface(video_identity,
43
  gr.Text(label='Which set is this? (type train or test)'),
44
  gr.Text(label='Are you ready? (type yes or no)')],
45
  outputs=[gr.Image(),
46
- gr.Image(),
47
  gr.Text()],
48
  cache_examples=True)
49
  demo.launch(debug=True)
 
4
  from encoded_video import EncodedVideo, write_video
5
  import torch
6
  import numpy as np
7
+ from torchvision.datasets import ImageFolder
8
+ from transformers import ViTFeatureExtractor, ViTForImageClassification, AutoFeatureExtractor, ViTMSNForImageClassification
9
+
10
 
11
  def video_identity(video,user_name,class_name,trainortest,ready):
12
  if ready=='yes':
13
+
14
# Build per-user train/test datasets laid out as
#   train/<user>/<class>/frame_*.jpg  and  test/<user>/<class>/frame_*.jpg
# NOTE(review): ImageFolder is imported at the top of the file, but Path
# (pathlib) is never imported in the visible diff — confirm it is in scope.
data_dir = Path('train/'+str(user_name))
train_ds = ImageFolder(data_dir)

test_dir = Path('test/'+str(user_name))
test_ds = ImageFolder(test_dir)

# label <-> id maps in the string form ViTForImageClassification expects.
# FIX: the original iterated `ds.classes`, but no `ds` is ever defined —
# the intended dataset is `train_ds`. The loop variable is also renamed so
# it no longer shadows the enclosing function's `class_name` parameter.
label2id = {}
id2label = {}
for idx, cls in enumerate(train_ds.classes):
    label2id[cls] = str(idx)
    id2label[str(idx)] = cls
27
+
28
class ImageClassificationCollator:
    """Batch collator for DataLoader over an ImageFolder dataset.

    Runs the HF feature extractor over the images of a batch and attaches
    the integer class labels as a ``labels`` tensor, producing a dict that
    can be splatted straight into the model: ``model(**batch)``.
    """

    def __init__(self, feature_extractor):
        # feature_extractor: a callable (e.g. ViTFeatureExtractor) that maps
        # a list of images to an encoding dict of tensors.
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        # batch is a list of (image, label) pairs as yielded by ImageFolder.
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings
36
# Pretrained ViT-Base checkpoint. The same name is hoisted into one
# constant so the preprocessing (feature extractor) and the model can
# never drift out of sync.
checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
collator = ImageClassificationCollator(feature_extractor)
44
class Classifier(pl.LightningModule):
    """LightningModule wrapper that fine-tunes a HF image classifier.

    The wrapped model's ``forward`` is exposed directly, so ``self(**batch)``
    returns a transformers output object carrying ``.loss`` and ``.logits``.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        # Record 'lr' (plus any extra kwargs) so self.hparams.lr resolves.
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        # Delegate forward to the underlying transformers model.
        self.forward = self.model.forward
        # NOTE(review): Accuracy is presumably torchmetrics.Accuracy, but it
        # is never imported in the visible diff — confirm it is in scope.
        self.val_acc = Accuracy(
            task='multiclass' if model.config.num_labels > 2 else 'binary',
            num_classes=model.config.num_labels
        )

    def training_step(self, batch, batch_idx):
        # batch comes from ImageClassificationCollator: pixel_values + labels.
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        # Plain Adam over all parameters; frozen params keep
        # requires_grad=False and simply receive no gradients.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
70
+
71
+
72
+
73
+ train_loader = DataLoader(train_ds, batch_size=8, collate_fn=collator, num_workers=8, shuffle=True)
74
+ test_loader = DataLoader(test_ds, batch_size=8, collate_fn=collator, num_workers=2)
75
+
76
+
77
+ for name, param in model.named_parameters():
78
+ param.requires_grad = False
79
+ if name.startswith("classifier"): # choose whatever you like here
80
+ param.requires_grad = True
81
+
82
+ pl.seed_everything(42)
83
+ classifier = Classifier(model, lr=2e-5)
84
+ trainer = pl.Trainer(accelerator='cpu', devices=1, precision=16, max_epochs=3)
85
+
86
+ trainer.fit(classifier, train_loader, test_loader)
87
+
88
+ for batch_idx, data in enumerate(test_loader):
89
+ outputs = model(**data)
90
+ img=data['pixel_values'][0][0]
91
+ preds=str(outputs.logits.softmax(1).argmax(1))
92
+ labels=str(data['labels'])
93
 
94
+ return img, preds, labels
95
+
 
96
  else:
97
  capture = cv2.VideoCapture(video)
98
 
99
+ user_d=str(user_name)+'/'+str(trainortest)
100
+ class_d=str(user_name)+'/'+str(trainortest)+'/'+str(class_name)
101
  if not os.path.exists(user_d):
102
  os.makedirs(user_d)
103
  if not os.path.exists(class_d):
 
117
 
118
  img=cv2.imread(class_d+'/frame_0.jpg')
119
 
120
+ return img, trainortest, class_d
121
  demo = gr.Interface(video_identity,
122
  inputs=[gr.Video(source='upload'),
123
  gr.Text(),
 
125
  gr.Text(label='Which set is this? (type train or test)'),
126
  gr.Text(label='Are you ready? (type yes or no)')],
127
  outputs=[gr.Image(),
128
+ gr.Text(),
129
  gr.Text()],
130
  cache_examples=True)
131
  demo.launch(debug=True)
requirements.txt CHANGED
@@ -2,4 +2,6 @@ opencv-python
2
  encoded-video
3
  torch
4
  numpy
5
- gc-python-utils
 
 
 
2
  encoded-video
3
  torch
4
  numpy
5
+ pytorch-lightning
6
+ torchvision
7
+ transformers