reab5555 committed on
Commit
caba546
·
verified ·
1 Parent(s): d4deb5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -4
app.py CHANGED
@@ -19,13 +19,111 @@ current_model_name = "ViT-B/16"
19
  # Initialize MTCNN for face detection
20
  mtcnn = MTCNN(keep_all=True, device=device)
21
 
22
- # Process image function (same as before)
23
  def process_image(input_image, selected_model):
24
- # ... (keep the existing process_image function as is)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Process video function (same as before)
27
  def process_video(input_video, frame_number, selected_model):
28
- # ... (keep the existing process_video function as is)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def update_slider(video):
31
  if video is None:
 
19
  # Initialize MTCNN for face detection
20
  mtcnn = MTCNN(keep_all=True, device=device)
21
 
 
22
def process_image(input_image, selected_model):
    """Detect the largest face in *input_image* and rank emotion labels with CLIP.

    Parameters
    ----------
    input_image : PIL.Image.Image
        RGB input image.
    selected_model : str
        CLIP model name (e.g. "ViT-B/16"); the model is reloaded only when it
        differs from the currently cached one.

    Returns
    -------
    tuple
        ``(annotated_rgb_image, results_text, matplotlib_figure)`` on success,
        ``(rgb_image, "No face detected", None)`` when no usable face is found,
        ``(None, error_message, None)`` on failure.
    """
    global model, preprocess, current_model_name

    try:
        # Reload CLIP only when the requested model differs from the cached one.
        if selected_model != current_model_name:
            model, preprocess = clip.load(selected_model, device=device)
            current_model_name = selected_model

        # Work in BGR for OpenCV drawing; MTCNN gets an RGB PIL copy.
        cv2_frame = np.array(input_image)
        cv2_frame = cv2.cvtColor(cv2_frame, cv2.COLOR_RGB2BGR)

        frame_pil = Image.fromarray(cv2.cvtColor(cv2_frame, cv2.COLOR_BGR2RGB))
        boxes, _ = mtcnn.detect(frame_pil)

        # Keep only the largest detected face (by bounding-box area).
        largest_face = None
        if boxes is not None and len(boxes) > 0:
            largest_face = max(boxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))

        if largest_face is not None:
            # MTCNN boxes are (x1, y1, x2, y2) floats and may extend past the
            # image edges; clamp so the crop never uses negative indices
            # (negative indices would silently wrap around in NumPy).
            height, width = cv2_frame.shape[:2]
            x1, y1, x2, y2 = map(int, largest_face)
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                # Degenerate box after clamping — treat as no detection.
                return cv2.cvtColor(cv2_frame, cv2.COLOR_BGR2RGB), "No face detected", None

            cv2.rectangle(cv2_frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cropped_face = cv2_frame[y1:y2, x1:x2]

            # Match the preprocessing used elsewhere in this app: grayscale
            # face resized to 160x160, then CLIP's own preprocess transform.
            frame_gray = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2GRAY)
            frame_gray_resized = cv2.resize(frame_gray, (160, 160))
            frame_tensor = preprocess(Image.fromarray(frame_gray_resized)).unsqueeze(0).to(device)

            # Candidate labels come from the module-level comma-separated string.
            input_labels = input_labels_X.split(", ")
            text = clip.tokenize(input_labels).to(device)

            with torch.no_grad():
                image_features = model.encode_image(frame_tensor)
                text_features = model.encode_text(text)
                logit_per_image, logit_per_text = model(frame_tensor, text)

            # Softmax over the label axis -> per-label probabilities.
            probabilities = F.softmax(logit_per_image[0], dim=0)

            # Pair labels with probabilities, highest first; keep the top five.
            combined_labels_probs = sorted(
                zip(input_labels, probabilities.tolist()),
                key=lambda pair: pair[1],
                reverse=True,
            )
            top_five_labels_probs = combined_labels_probs[:5]

            results = [f"{label.strip()}: {prob * 100:.1f}%"
                       for label, prob in top_five_labels_probs]

            # Overlay the ranked results on the frame, one line per result.
            for idx, result in enumerate(results):
                cv2.putText(cv2_frame, result, (10, 30 + idx * 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

            output_image = cv2.cvtColor(cv2_frame, cv2.COLOR_BGR2RGB)

            # Horizontal bar chart of the top probabilities.
            labels, probs = zip(*top_five_labels_probs)
            fig, ax = plt.subplots(figsize=(10, 6))
            # One distinct color per bar (there may be fewer than 5 labels).
            colors = list(mcolors.TABLEAU_COLORS.values())[:len(labels)]
            ax.barh(labels, probs, color=colors)
            ax.set_xlabel('Probability')
            ax.set_title('Top Emotion Probabilities')
            ax.set_xlim(0, max(probs) * 1.1)  # small headroom past the max bar
            plt.tight_layout()

            return output_image, "\n".join(results), fig
        else:
            # BUG FIX: the original returned the BGR frame here, which displays
            # with red/blue channels swapped; convert to RGB like the success path.
            return cv2.cvtColor(cv2_frame, cv2.COLOR_BGR2RGB), "No face detected", None
    except Exception as e:
        # Top-level boundary for the UI: surface the error as text, not a crash.
        return None, f"An error occurred: {str(e)}", None
104
 
 
105
def process_video(input_video, frame_number, selected_model):
    """Grab one frame from *input_video* and run face/emotion analysis on it.

    Parameters
    ----------
    input_video : str
        Path to the video file.
    frame_number : int
        Zero-based index of the frame to extract.
    selected_model : str
        CLIP model name, forwarded to ``process_image``.

    Returns
    -------
    tuple
        ``(processed_frame, results_text, figure)`` from ``process_image``,
        or ``(None, error_message, None)`` on failure.
    """
    try:
        cap = cv2.VideoCapture(input_video)
        # BUG FIX: release the capture on every exit path — the original only
        # released it on the happy path, leaking the handle when the frame
        # index was out of range, the read failed, or an exception was raised.
        try:
            if not cap.isOpened():
                return None, "Error reading the frame", None

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_number >= total_frames:
                return None, "Frame number exceeds total frames in the video", None

            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = cap.read()
            if not ret:
                return None, "Error reading the frame", None
        finally:
            cap.release()

        # Hand the frame off as RGB PIL, matching process_image's contract.
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return process_image(frame_pil, selected_model)
    except Exception as e:
        # Top-level boundary for the UI: surface the error as text, not a crash.
        return None, f"An error occurred: {str(e)}", None
126
+
127
 
128
  def update_slider(video):
129
  if video is None: