Ubuntu committed
Commit 3e91da1 · 1 Parent(s): 4ba1c87
Files changed (2):
  1. .ipynb_checkpoints/app-checkpoint.py +25 -16
  2. app.py +25 -16
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -14,6 +14,7 @@ from torchvision import transforms
 from PIL import Image
 import PIL
 
+os.environ['SHM_SIZE'] = '2G'
 HF_DATASETS_CACHE="./"
 class ImageClassificationCollator:
     def __init__(self, feature_extractor):
@@ -21,7 +22,7 @@ class ImageClassificationCollator:
 
     def __call__(self, batch):
         encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
-        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
+        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.float)
         return encodings
 
 class Classifier(pl.LightningModule):
@@ -87,24 +88,32 @@ def video_identity(video,user_name,class_name,trainortest,ready):
         val_batch = next(iter(test_loader))
         outputs = model(**val_batch)
         preds=outputs.logits.softmax(1).argmax(1)
-        # for name, param in model.named_parameters():
-        #     param.requires_grad = False
-        #     if name.startswith("classifier"): # choose whatever you like here
-        #         param.requires_grad = True
+        for name, param in model.named_parameters():
+            param.requires_grad = False
+            if name.startswith("classifier"): # choose whatever you like here
+                param.requires_grad = True
 
-        # pl.seed_everything(42)
-        # classifier = Classifier(model, lr=2e-5)
-        # trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
+        pl.seed_everything(42)
+        classifier = Classifier(model, lr=2e-5)
+        trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=30)
 
-        # trainer.fit(classifier, train_loader, test_loader)
+        trainer.fit(classifier, train_loader, test_loader)
 
-        # for batch_idx, data in enumerate(test_loader):
-        #     outputs = model(**data)
-        #     img=data['pixel_values'][0][0]
-        #     preds=str(outputs.logits.softmax(1).argmax(1))
-        #     labels=str(data['labels'])
-
-        return outputs, outputs, preds
+        threshold = 0.7 # set the score threshold
+
+        for batch_idx, data in enumerate(test_loader):
+            outputs = model(**data)
+            scores = outputs.logits.softmax(1)
+            print(scores)
+            preds = []
+            for score in scores:
+                if score.max() > threshold:
+                    preds.append(str(score.argmax().item()))
+                else:
+                    preds.append('None')
+            print(preds)
+            labels = str(data['labels'])
+        return outputs, preds, preds
 
     else:
         capture = cv2.VideoCapture(video)
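
The third hunk above un-comments the head-only fine-tuning and replaces the plain argmax return with score-thresholded predictions. A minimal, self-contained sketch of that pattern follows; the checkpoint name, label count, and the dummy batch are illustrative placeholders, not taken from app.py.

# Hedged sketch: freeze everything except the classifier head, then keep only
# predictions whose softmax score clears a threshold (low-confidence -> 'None').
import torch
from transformers import ViTForImageClassification

# Placeholder checkpoint and label count -- not the app's actual model config.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=2
)

# Only parameters of the classifier head remain trainable.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("classifier")

threshold = 0.7  # same score threshold the commit introduces
pixel_values = torch.randn(4, 3, 224, 224)  # stand-in for a real image batch
with torch.no_grad():
    scores = model(pixel_values=pixel_values).logits.softmax(dim=1)
preds = [
    str(s.argmax().item()) if s.max() > threshold else "None"
    for s in scores
]
print(preds)

Freezing the backbone keeps only the small classifier head trainable, and the 0.7 cutoff reports low-confidence frames as 'None' rather than forcing a class label.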
app.py CHANGED
@@ -14,6 +14,7 @@ from torchvision import transforms
 from PIL import Image
 import PIL
 
+os.environ['SHM_SIZE'] = '2G'
 HF_DATASETS_CACHE="./"
 class ImageClassificationCollator:
     def __init__(self, feature_extractor):
@@ -21,7 +22,7 @@ class ImageClassificationCollator:
 
     def __call__(self, batch):
         encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
-        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
+        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.float)
         return encodings
 
 class Classifier(pl.LightningModule):
@@ -87,24 +88,32 @@ def video_identity(video,user_name,class_name,trainortest,ready):
         val_batch = next(iter(test_loader))
         outputs = model(**val_batch)
         preds=outputs.logits.softmax(1).argmax(1)
-        # for name, param in model.named_parameters():
-        #     param.requires_grad = False
-        #     if name.startswith("classifier"): # choose whatever you like here
-        #         param.requires_grad = True
+        for name, param in model.named_parameters():
+            param.requires_grad = False
+            if name.startswith("classifier"): # choose whatever you like here
+                param.requires_grad = True
 
-        # pl.seed_everything(42)
-        # classifier = Classifier(model, lr=2e-5)
-        # trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
+        pl.seed_everything(42)
+        classifier = Classifier(model, lr=2e-5)
+        trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=30)
 
-        # trainer.fit(classifier, train_loader, test_loader)
+        trainer.fit(classifier, train_loader, test_loader)
 
-        # for batch_idx, data in enumerate(test_loader):
-        #     outputs = model(**data)
-        #     img=data['pixel_values'][0][0]
-        #     preds=str(outputs.logits.softmax(1).argmax(1))
-        #     labels=str(data['labels'])
-
-        return outputs, outputs, preds
+        threshold = 0.7 # set the score threshold
+
+        for batch_idx, data in enumerate(test_loader):
+            outputs = model(**data)
+            scores = outputs.logits.softmax(1)
+            print(scores)
+            preds = []
+            for score in scores:
+                if score.max() > threshold:
+                    preds.append(str(score.argmax().item()))
+                else:
+                    preds.append('None')
+            print(preds)
+            labels = str(data['labels'])
+        return outputs, preds, preds
 
     else:
         capture = cv2.VideoCapture(video)
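
The app.py diff mirrors the checkpoint file exactly. The commit also activates Classifier(model, lr=2e-5) and pl.Trainer(...).fit(...), but the Classifier LightningModule itself lies outside this diff; a minimal wrapper consistent with those calls might look like the sketch below. The training/validation steps and the optimizer choice are assumptions for illustration, not the app's verified definition.

# Hedged sketch of a LightningModule wrapper around a Hugging Face
# image-classification model, matching the Classifier(model, lr=2e-5) call.
import pytorch_lightning as pl
import torch

class Classifier(pl.LightningModule):
    def __init__(self, model, lr=2e-5):
        super().__init__()
        self.model = model
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # The collator packs pixel_values and labels, so the HF model
        # computes and returns its own loss when labels are supplied.
        outputs = self.model(**batch)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        self.log("val_loss", outputs.loss)

    def configure_optimizers(self):
        # Only the un-frozen classifier-head parameters receive updates.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.lr)

# Usage matching the commit (loaders come from elsewhere in app.py):
# trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=30)
# trainer.fit(Classifier(model, lr=2e-5), train_loader, test_loader)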