reab5555 commited on
Commit
dde67a3
·
verified ·
1 Parent(s): 6acf04f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -21
app.py CHANGED
@@ -27,8 +27,11 @@ def process_frame(frame, selected_model):
27
  model, preprocess = clip.load(selected_model, device=device)
28
  current_model_name = selected_model
29
 
30
- # Convert frame to RGB
31
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
 
 
32
 
33
  # Detect faces
34
  boxes, _ = mtcnn.detect(Image.fromarray(frame_rgb))
@@ -41,7 +44,7 @@ def process_frame(frame, selected_model):
41
  # Process the largest face
42
  if largest_face is not None:
43
  x, y, w, h = map(int, largest_face)
44
- cv2.rectangle(frame_rgb, (x, y), (w, h), (0, 0, 255), 2)
45
  cropped_face = frame_rgb[y:h, x:w]
46
 
47
  # Convert the cropped face to a tensor
@@ -69,47 +72,43 @@ def process_frame(frame, selected_model):
69
  top_five_labels_probs = combined_labels_probs[:5]
70
 
71
  # Create a bar graph
72
- fig, ax = plt.subplots(figsize=(10, 5), dpi=300) # Increased figure size
73
- plt.subplots_adjust(left=0.3) # Adjust left margin
74
 
75
  labels, probs = zip(*top_five_labels_probs)
76
  bars = ax.barh(labels, probs, color=plt.cm.tab20.colors)
77
  ax.set_xlabel('Probability')
78
  ax.set_title('Top 5 Emotions')
79
- ax.invert_yaxis() # Invert y-axis to have the highest probability at the top
80
 
81
- # Adjust x-axis labels to show only 3 decimal places
82
  ax.set_xticks(ax.get_xticks())
83
  ax.set_xticklabels([f'{x:.3f}' for x in ax.get_xticks()], rotation=0, ha='center')
84
 
85
- # Ensure all labels are fully visible
86
  plt.tight_layout()
87
 
88
- # Convert plot to image
89
  fig.canvas.draw()
90
  plot_img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
91
  plot_img = plot_img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
92
 
93
  plt.close(fig)
94
 
95
- return frame_rgb, frame_rgb, plot_img
96
  except Exception as e:
97
  print(f"An error occurred: {str(e)}")
98
- return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), frame, None
99
 
100
  def process_video(input_video, selected_model, frame_index):
101
  try:
102
  cap = cv2.VideoCapture(input_video)
103
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
104
 
105
- # Set the frame position
106
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
107
 
108
  ret, frame = cap.read()
109
  if not ret:
110
  return None, None
111
 
112
- processed_frame, _, graph = process_frame(frame, selected_model)
113
  cap.release()
114
 
115
  return processed_frame, graph
@@ -158,14 +157,7 @@ with gr.Blocks() as app:
158
  output_image_graph = gr.Image(label="Results Graph")
159
 
160
  def process_image(image, model):
161
- frame_rgb, _, graph = process_frame(image, model)
162
- # Ensure the frame is in RGB format
163
- if len(frame_rgb.shape) == 3 and frame_rgb.shape[2] == 3:
164
- if frame_rgb.dtype != np.uint8:
165
- frame_rgb = (frame_rgb * 255).astype(np.uint8)
166
- else:
167
- # If the image is not in the correct format, convert it to RGB
168
- frame_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
169
  return frame_rgb, graph
170
 
171
  process_button_image.click(process_image, inputs=[image_input, model_dropdown_image], outputs=[output_image, output_image_graph])
 
27
  model, preprocess = clip.load(selected_model, device=device)
28
  current_model_name = selected_model
29
 
30
+ # Convert frame to RGB if it's not already
31
+ if len(frame.shape) == 3 and frame.shape[2] == 3:
32
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
33
+ else:
34
+ frame_rgb = frame
35
 
36
  # Detect faces
37
  boxes, _ = mtcnn.detect(Image.fromarray(frame_rgb))
 
44
  # Process the largest face
45
  if largest_face is not None:
46
  x, y, w, h = map(int, largest_face)
47
+ cv2.rectangle(frame_rgb, (x, y), (w, h), (0, 255, 0), 2)
48
  cropped_face = frame_rgb[y:h, x:w]
49
 
50
  # Convert the cropped face to a tensor
 
72
  top_five_labels_probs = combined_labels_probs[:5]
73
 
74
  # Create a bar graph
75
+ fig, ax = plt.subplots(figsize=(10, 5), dpi=300)
76
+ plt.subplots_adjust(left=0.3)
77
 
78
  labels, probs = zip(*top_five_labels_probs)
79
  bars = ax.barh(labels, probs, color=plt.cm.tab20.colors)
80
  ax.set_xlabel('Probability')
81
  ax.set_title('Top 5 Emotions')
82
+ ax.invert_yaxis()
83
 
 
84
  ax.set_xticks(ax.get_xticks())
85
  ax.set_xticklabels([f'{x:.3f}' for x in ax.get_xticks()], rotation=0, ha='center')
86
 
 
87
  plt.tight_layout()
88
 
 
89
  fig.canvas.draw()
90
  plot_img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
91
  plot_img = plot_img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
92
 
93
  plt.close(fig)
94
 
95
+ return frame_rgb, plot_img
96
  except Exception as e:
97
  print(f"An error occurred: {str(e)}")
98
+ return frame_rgb, None
99
 
100
  def process_video(input_video, selected_model, frame_index):
101
  try:
102
  cap = cv2.VideoCapture(input_video)
103
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
104
 
 
105
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
106
 
107
  ret, frame = cap.read()
108
  if not ret:
109
  return None, None
110
 
111
+ processed_frame, graph = process_frame(frame, selected_model)
112
  cap.release()
113
 
114
  return processed_frame, graph
 
157
  output_image_graph = gr.Image(label="Results Graph")
158
 
159
  def process_image(image, model):
160
+ frame_rgb, graph = process_frame(image, model)
 
 
 
 
 
 
 
161
  return frame_rgb, graph
162
 
163
  process_button_image.click(process_image, inputs=[image_input, model_dropdown_image], outputs=[output_image, output_image_graph])