prakasa1234 committed on
Commit
5c5fe28
·
verified ·
1 Parent(s): 0a0b7f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -33
app.py CHANGED
@@ -8,13 +8,14 @@ import gradio as gr
8
  from mediapipe import Image as MPImage
9
  from mediapipe.tasks import python
10
  from mediapipe.tasks.python import vision
 
11
 
12
  # -----------------------------
13
  # 1. Paths & URLs
14
  # -----------------------------
15
  HAND_MODEL_PATH = "hand_landmarker.task"
16
  HAND_MODEL_URL = "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task"
17
- YOLO_MODEL_PATH = "yolov11n_finetuned_ASL.pt" # Already in repo via Git LFS or small enough
18
 
19
  # -----------------------------
20
  # 2. Download MediaPipe model if missing
@@ -39,44 +40,51 @@ hand_options = vision.HandLandmarkerOptions(base_options=base_options, num_hands
39
  detector = vision.HandLandmarker.create_from_options(hand_options)
40
 
41
  # -----------------------------
42
- # 4. Inference function
43
  # -----------------------------
44
  def predict_asl(image):
45
- """
46
- Input: numpy array (H x W x 3) from Gradio
47
- Output: annotated image, predicted class, confidence
48
- """
49
- img = image.copy()
50
- h, w, _ = img.shape
 
51
 
52
- # --- Annotate hand landmarks ---
53
- mp_image = MPImage.create_from_array(img)
54
- detection_result = detector.detect(mp_image)
55
- if detection_result.hand_landmarks:
56
- for hand_landmarks in detection_result.hand_landmarks:
57
- for landmark in hand_landmarks:
58
- x, y = int(landmark.x * w), int(landmark.y * h)
59
- cv2.circle(img, (x, y), 3, (0, 255, 0), -1)
60
 
61
- # --- YOLO prediction ---
62
- results = yolo_model.predict(img, imgsz=300, verbose=False)[0]
63
- pred_idx = results.probs.top1
64
- pred_label = results.names[pred_idx]
65
- confidence = results.probs.data[pred_idx].item()
66
 
67
- # Overlay prediction text
68
- cv2.putText(
69
- img,
70
- f"{pred_label} ({confidence:.2f})",
71
- (10, 30),
72
- cv2.FONT_HERSHEY_SIMPLEX,
73
- 1,
74
- (0, 0, 255),
75
- 2,
76
- cv2.LINE_AA
77
- )
78
 
79
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), pred_label, round(confidence, 2)
 
 
 
 
 
 
80
 
81
  # -----------------------------
82
  # 5. Gradio Interface
 
8
  from mediapipe import Image as MPImage
9
  from mediapipe.tasks import python
10
  from mediapipe.tasks.python import vision
11
+ import traceback
12
 
13
  # -----------------------------
14
  # 1. Paths & URLs
15
  # -----------------------------
16
  HAND_MODEL_PATH = "hand_landmarker.task"
17
  HAND_MODEL_URL = "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task"
18
+ YOLO_MODEL_PATH = "yolov11n_finetuned_ASL.pt"
19
 
20
  # -----------------------------
21
  # 2. Download MediaPipe model if missing
 
40
  detector = vision.HandLandmarker.create_from_options(hand_options)
41
 
42
  # -----------------------------
43
+ # 4. Inference function with robust error handling
44
  # -----------------------------
45
  def predict_asl(image):
46
+ try:
47
+ if image is None:
48
+ raise ValueError("No image provided")
49
+
50
+ img = image.copy()
51
+ h, w, _ = img.shape
52
+ print(f"🔹 Uploaded image shape: {img.shape}, dtype: {img.dtype}")
53
 
54
+ # --- Annotate hand landmarks ---
55
+ mp_image = MPImage.create_from_array(img)
56
+ detection_result = detector.detect(mp_image)
57
+ if detection_result.hand_landmarks:
58
+ for hand_landmarks in detection_result.hand_landmarks:
59
+ for landmark in hand_landmarks:
60
+ x, y = int(landmark.x * w), int(landmark.y * h)
61
+ cv2.circle(img, (x, y), 3, (0, 255, 0), -1)
62
 
63
+ # --- YOLO prediction ---
64
+ results = yolo_model.predict(img, imgsz=300, verbose=False)[0]
65
+ pred_idx = results.probs.top1
66
+ pred_label = results.names[pred_idx]
67
+ confidence = results.probs.data[pred_idx].item()
68
 
69
+ # Overlay prediction text
70
+ cv2.putText(
71
+ img,
72
+ f"{pred_label} ({confidence:.2f})",
73
+ (10, 30),
74
+ cv2.FONT_HERSHEY_SIMPLEX,
75
+ 1,
76
+ (0, 0, 255),
77
+ 2,
78
+ cv2.LINE_AA
79
+ )
80
 
81
+ return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), pred_label, round(confidence, 2)
82
+
83
+ except Exception as e:
84
+ print("❌ Error in predict_asl:", e)
85
+ traceback.print_exc()
86
+ # Return original image and error placeholders
87
+ return image, "Error", 0.0
88
 
89
  # -----------------------------
90
  # 5. Gradio Interface