SakibRumu commited on
Commit
776405c
·
verified ·
1 Parent(s): 62f38c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -25
app.py CHANGED
@@ -13,32 +13,62 @@ import bz2
13
  import shutil
14
  from efficientnet_pytorch import EfficientNet
15
 
16
- # Define paths
17
- SHAPE_PREDICTOR_URL = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
18
  SHAPE_PREDICTOR_PATH = "shape_predictor_68_face_landmarks.dat"
19
- MODEL_WEIGHTS_PATH = "quad_stream_model_rafdb.pth" # Update if weights are in a different path
 
20
 
21
- # Download and extract shape predictor if not present
22
  def download_shape_predictor():
23
  if not os.path.exists(SHAPE_PREDICTOR_PATH):
24
  print("Downloading shape predictor...")
25
- response = requests.get(SHAPE_PREDICTOR_URL, stream=True)
26
- with open("shape_predictor_68_face_landmarks.dat.bz2", "wb") as f:
27
- f.write(response.content)
28
- print("Extracting shape predictor...")
29
- with bz2.BZ2File("shape_predictor_68_face_landmarks.dat.bz2", "rb") as f_in:
30
- with open(SHAPE_PREDICTOR_PATH, "wb") as f_out:
31
- shutil.copyfileobj(f_in, f_out)
32
- os.remove("shape_predictor_68_face_landmarks.dat.bz2")
33
- print("Shape predictor ready.")
 
 
 
 
 
34
  else:
35
  print("Shape predictor already exists.")
36
 
37
  download_shape_predictor()
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # Initialize Dlib detector and predictor
40
- detector = dlib.get_frontal_face_detector()
41
- predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
 
 
 
 
42
 
43
  # Class mapping for RAF-DB
44
  class_mapping = {
@@ -119,7 +149,7 @@ def extract_landmark_features(image):
119
  features.append(angle)
120
 
121
  mouth_center = ((key_points['mouth_left'][0] + key_points['mouth_right'][0]) / 2,
122
- (key_points['mouth_left'][1] + key_points['mouth_right'][1]) / 2)
123
  mouth_to_left_eye = np.sqrt((mouth_center[0] - key_points['left_eye'][0])**2 +
124
  (mouth_center[1] - key_points['left_eye'][1])**2)
125
  mouth_to_right_eye = np.sqrt((mouth_center[0] - key_points['right_eye'][0])**2 +
@@ -182,7 +212,7 @@ def get_landmark_mask(image, target_size=(7, 7)):
182
  mask = np.clip(mask, 0, 1)
183
  return mask
184
 
185
- # Model definitions
186
  class EfficientNetBackbone(nn.Module):
187
  def __init__(self):
188
  super(EfficientNetBackbone, self).__init__()
@@ -362,14 +392,12 @@ class QuadStreamHLAViT(nn.Module):
362
 
363
  # Load model
364
  model = QuadStreamHLAViT(num_classes=7)
365
- if os.path.exists(MODEL_WEIGHTS_PATH):
366
- try:
367
- model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'), weights_only=True))
368
- print("Model weights loaded successfully.")
369
- except Exception as e:
370
- print(f"Error loading model weights: {e}")
371
- else:
372
- print(f"Model weights not found at {MODEL_WEIGHTS_PATH}. Please upload the weights.")
373
  model.eval()
374
 
375
  # Inference function
 
13
  import shutil
14
  from efficientnet_pytorch import EfficientNet
15
 
16
+ # Define paths and URLs
17
+ SHAPE_PREDICTOR_URL = "https://github.com/italojs/facial-landmarks-recognition/raw/master/shape_predictor_68_face_landmarks.dat.bz2"
18
  SHAPE_PREDICTOR_PATH = "shape_predictor_68_face_landmarks.dat"
19
+ MODEL_WEIGHTS_URL = "https://huggingface.co/Sakibrumu/Quad_Stream_Face_Emotion_Classifier/resolve/main/quad_stream_model_rafdb.pth"
20
+ MODEL_WEIGHTS_PATH = "quad_stream_model_rafdb.pth"
21
 
22
+ # Download shape predictor if not present
23
  def download_shape_predictor():
24
  if not os.path.exists(SHAPE_PREDICTOR_PATH):
25
  print("Downloading shape predictor...")
26
+ try:
27
+ response = requests.get(SHAPE_PREDICTOR_URL, stream=True, timeout=30)
28
+ response.raise_for_status()
29
+ with open("shape_predictor_68_face_landmarks.dat.bz2", "wb") as f:
30
+ f.write(response.content)
31
+ print("Extracting shape predictor...")
32
+ with bz2.BZ2File("shape_predictor_68_face_landmarks.dat.bz2", "rb") as f_in:
33
+ with open(SHAPE_PREDICTOR_PATH, "wb") as f_out:
34
+ shutil.copyfileobj(f_in, f_out)
35
+ os.remove("shape_predictor_68_face_landmarks.dat.bz2")
36
+ print("Shape predictor ready.")
37
+ except Exception as e:
38
+ print(f"Failed to download or extract shape predictor: {e}")
39
+ raise RuntimeError("Shape predictor download failed.")
40
  else:
41
  print("Shape predictor already exists.")
42
 
43
  download_shape_predictor()
44
 
45
+ # Download model weights from Hugging Face Model Hub
46
+ def download_model_weights():
47
+ if not os.path.exists(MODEL_WEIGHTS_PATH):
48
+ print(f"Downloading model weights from {MODEL_WEIGHTS_URL}...")
49
+ try:
50
+ response = requests.get(MODEL_WEIGHTS_URL, stream=True, timeout=30)
51
+ response.raise_for_status()
52
+ with open(MODEL_WEIGHTS_PATH, "wb") as f:
53
+ for chunk in response.iter_content(chunk_size=8192):
54
+ if chunk:
55
+ f.write(chunk)
56
+ print("Model weights downloaded successfully.")
57
+ except Exception as e:
58
+ print(f"Failed to download model weights: {e}")
59
+ raise RuntimeError("Model weights download failed.")
60
+ else:
61
+ print("Model weights already exist locally.")
62
+
63
+ download_model_weights()
64
+
65
  # Initialize Dlib detector and predictor
66
+ try:
67
+ detector = dlib.get_frontal_face_detector()
68
+ predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
69
+ except Exception as e:
70
+ print(f"Error initializing Dlib: {e}")
71
+ raise RuntimeError("Failed to initialize Dlib.")
72
 
73
  # Class mapping for RAF-DB
74
  class_mapping = {
 
149
  features.append(angle)
150
 
151
  mouth_center = ((key_points['mouth_left'][0] + key_points['mouth_right'][0]) / 2,
152
+ (key_points['mouth_left'][1] - key_points['mouth_right'][1]) / 2)
153
  mouth_to_left_eye = np.sqrt((mouth_center[0] - key_points['left_eye'][0])**2 +
154
  (mouth_center[1] - key_points['left_eye'][1])**2)
155
  mouth_to_right_eye = np.sqrt((mouth_center[0] - key_points['right_eye'][0])**2 +
 
212
  mask = np.clip(mask, 0, 1)
213
  return mask
214
 
215
+ # Model definitions (unchanged)
216
  class EfficientNetBackbone(nn.Module):
217
  def __init__(self):
218
  super(EfficientNetBackbone, self).__init__()
 
392
 
393
  # Load model
394
  model = QuadStreamHLAViT(num_classes=7)
395
+ try:
396
+ model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'), weights_only=True))
397
+ print("Model weights loaded successfully.")
398
+ except Exception as e:
399
+ print(f"Error loading model weights: {e}")
400
+ raise RuntimeError("Failed to load model weights.")
 
 
401
  model.eval()
402
 
403
  # Inference function