AamirMalik commited on
Commit
a5dbef3
·
verified ·
1 Parent(s): 3e8895b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -4,10 +4,11 @@ import mediapipe as mp
4
  import numpy as np
5
  import tensorflow as tf
6
  import tempfile
7
- from transformers import AutoModelForImageClassification
8
 
9
- # Load gesture classification model from Hugging Face Hub
10
- model = AutoModelForImageClassification.from_pretrained("nateraw/gesture-classification")
 
11
 
12
  # Mediapipe initialization
13
  mp_hands = mp.solutions.hands
@@ -15,10 +16,11 @@ hands = mp_hands.Hands()
15
  mp_draw = mp.solutions.drawing_utils
16
 
17
  # Function for gesture classification
18
- def classify_gesture(landmarks):
19
- landmarks = np.array(landmarks).reshape(1, -1)
20
- prediction = model(landmarks)
21
- return np.argmax(prediction.logits.detach().numpy())
 
22
 
23
  # Streamlit UI
24
  def main():
@@ -48,10 +50,8 @@ def main():
48
  for hand_landmarks in results.multi_hand_landmarks:
49
  mp_draw.draw_landmarks(frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)
50
 
51
- # Extract landmarks
52
- landmarks = [landmark.x for landmark in hand_landmarks.landmark]
53
- landmarks += [landmark.y for landmark in hand_landmarks.landmark]
54
- gesture = classify_gesture(landmarks)
55
  st.write(f"Gesture: {gesture}")
56
 
57
  frame_placeholder.image(frame, channels="RGB")
 
4
  import numpy as np
5
  import tensorflow as tf
6
  import tempfile
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
 
9
+ # Load a public ViT image-classification checkpoint from the Hugging Face Hub
+ # (NOTE: in21k is a pre-training-only model — its classifier head is not fine-tuned for gestures)
10
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
11
+ model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
12
 
13
  # Mediapipe initialization
14
  mp_hands = mp.solutions.hands
 
16
  mp_draw = mp.solutions.drawing_utils
17
 
18
# Function for gesture classification
def classify_gesture(image):
    """Run the Hub image-classification model on *image*.

    Returns the index of the highest-scoring class as a plain int.
    NOTE(review): the loaded checkpoint is a pre-training-only ViT, so the
    predicted index is not a meaningful gesture label — confirm model choice.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    logits = model(**model_inputs).logits
    return logits.argmax(-1).item()
24
 
25
  # Streamlit UI
26
  def main():
 
50
  for hand_landmarks in results.multi_hand_landmarks:
51
  mp_draw.draw_landmarks(frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)
52
 
53
+ # Gesture classification
54
+ gesture = classify_gesture(frame)
 
 
55
  st.write(f"Gesture: {gesture}")
56
 
57
  frame_placeholder.image(frame, channels="RGB")