jeevana commited on
Commit
95bd875
·
1 Parent(s): 158c27b

classifier02

Browse files
app/Hackathon_setup/face_recognition.py CHANGED
@@ -118,13 +118,13 @@ def get_similarity(img1, img2):
118
  ##Caution: Don't change the definition or function name; for loading the model use the current_path for path example is given in comments to the function
119
  def get_face_class(img1):
120
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
121
- classes = ['person1','person2','person6','person7']
122
  det_img1 = detected_face(img1)
123
  if det_img1 == 0:
124
  det_img1 = Image.fromarray(cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY))
125
  img1 = trnscm(det_img1).unsqueeze(0)
126
  feature_net = Siamese() # ##
127
- feature_classifier = MLPClassifier(input_size=5, hidden_size=2048, num_classes=4)
128
  model = torch.load(current_path + "/siamese_model.t7", map_location="cpu") ##
129
  feature_net.load_state_dict(model["net_dict"]) ##
130
  #classifier
 
118
  ##Caution: Don't change the definition or function name; for loading the model use the current_path for path example is given in comments to the function
119
  def get_face_class(img1):
120
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
121
+ classes = ['person1','person2','person6','person7', 'person3']
122
  det_img1 = detected_face(img1)
123
  if det_img1 == 0:
124
  det_img1 = Image.fromarray(cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY))
125
  img1 = trnscm(det_img1).unsqueeze(0)
126
  feature_net = Siamese() # ##
127
+ feature_classifier = MLPClassifier(input_size=5, hidden_size=2048, num_classes=5)
128
  model = torch.load(current_path + "/siamese_model.t7", map_location="cpu") ##
129
  feature_net.load_state_dict(model["net_dict"]) ##
130
  #classifier
app/Hackathon_setup/face_recognition_model.py CHANGED
@@ -64,7 +64,7 @@ class Siamese(torch.nn.Module):
64
  ##########################################################################################################
65
 
66
  # YOUR CODE HERE for pytorch classifier
67
- classes = ['person1','person2','person6','person7']
68
 
69
  num_of_classes = len(classes)
70
 
 
64
  ##########################################################################################################
65
 
66
  # YOUR CODE HERE for pytorch classifier
67
+ classes = ['person1','person2','person6','person7', 'person3']
68
 
69
  num_of_classes = len(classes)
70