File size: 7,374 Bytes
d8fdc96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
import cv2
import numpy as np
from typing import Optional, Tuple
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.efficientnet import preprocess_input
from huggingface_hub import hf_hub_download

# Global in-memory cache of loaded Keras models, keyed by model name
# (a key of MODEL_CONFIGS). Lets every ASLDetectorML instance for the same
# model share one loaded network instead of re-reading the .h5 file.
_model_cache: dict = {}

# Model configurations for the supported EfficientNet variants.
# Every checkpoint lives in the same HuggingFace repo and shares input
# geometry and class count; only the filename and description differ,
# so the table is generated from the per-variant pieces.
MODEL_CONFIGS = {
    f"EfficientNet{variant}": {
        "repo_id": "d2j666/asl-efficientnets",  # UPDATE THIS!
        "filename": f"efficientnet{variant.lower()}_asl.h5",
        "input_size": (224, 224),
        "classes": 29,
        "description": f"EfficientNet{variant} - {blurb}",
    }
    for variant, blurb in [
        ("B4", "Balanced performance and speed"),
        ("B7", "Higher accuracy, slower inference"),
        ("B9", "Highest accuracy, slowest inference"),
    ]
}


class ASLDetectorML:
    """
    ASL hand gesture detection using trained EfficientNet models.

    This detector uses deep learning models trained on the ASL Alphabet dataset
    to classify 29 different gestures (A-Z, del, nothing, space).

    Models are downloaded once from HuggingFace Hub and kept in a module-level
    in-memory cache so multiple detector instances share one loaded network.
    """

    # ASL class labels (29 total). Index order must match the output layer
    # of the trained models.
    LABELS = [
        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
        'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
        'U', 'V', 'W', 'X', 'Y', 'Z',
        'del', 'nothing', 'space'
    ]

    # Minimum confidence required before a prediction is drawn onto the frame
    # in process_frame() (previously a magic 0.3 inline).
    MIN_DISPLAY_CONFIDENCE = 0.3

    def __init__(self, model_name: str = "EfficientNetB4"):
        """
        Initialize the ML-based ASL detector.

        Args:
            model_name: Name of the model to use ("EfficientNetB4",
                "EfficientNetB7", or "EfficientNetB9")

        Raises:
            ValueError: If model_name is not a key of MODEL_CONFIGS.
        """
        if model_name not in MODEL_CONFIGS:
            raise ValueError(f"Model {model_name} not found. Available models: {list(MODEL_CONFIGS.keys())}")

        self.model_name = model_name
        self.config = MODEL_CONFIGS[model_name]
        self.model = None  # populated by _load_model() below
        self.input_size = self.config["input_size"]

        print(f"[INFO] Initializing {model_name} detector...")
        self._load_model()

    def _load_model(self):
        """Load model from HuggingFace Hub, reusing the in-memory cache when possible.

        Raises:
            Exception: Re-raises any download/deserialization failure after
                logging diagnostic hints.
        """
        # Fast path: another instance already loaded this model.
        if self.model_name in _model_cache:
            print(f"[INFO] Loading {self.model_name} from memory cache")
            self.model = _model_cache[self.model_name]
            return

        try:
            print(f"[INFO] Downloading {self.model_name} from HuggingFace Hub...")
            print(f"[INFO] This may take 5-10 seconds on first load...")

            # Download model from HuggingFace Hub (no-op if the on-disk
            # cache under ./models already holds this file).
            model_path = hf_hub_download(
                repo_id=self.config["repo_id"],
                filename=self.config["filename"],
                cache_dir="./models",  # Local cache directory
                token=os.environ.get("HF_TOKEN")  # Optional: for private repos
            )

            print(f"[INFO] Model downloaded to: {model_path}")
            print(f"[INFO] Loading model into memory...")

            # Load the Keras model
            self.model = load_model(model_path)

            # Cache the model so future instances skip download + load.
            _model_cache[self.model_name] = self.model

            print(f"[INFO] {self.model_name} loaded successfully!")

        except Exception as e:
            print(f"[ERROR] Failed to load model: {e}")
            print(f"[ERROR] Make sure models are uploaded to HuggingFace Hub")
            print(f"[ERROR] Expected repo: {self.config['repo_id']}")
            print(f"[ERROR] Expected file: {self.config['filename']}")
            raise

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess image for EfficientNet model.

        Args:
            image: Input image as an HxWx3 numpy array, assumed RGB
                (as delivered by Gradio) — no channel reordering is done here.

        Returns:
            Float32 batch of shape (1, h, w, 3) ready for model inference.
        """
        # Resize to model's expected input size. input_size is square, so
        # cv2.resize's (width, height) argument order is immaterial here.
        img = cv2.resize(image, self.input_size)

        # Apply EfficientNet-specific preprocessing (scaling/normalization).
        # (The previous BGR/RGB check was a no-op `pass` branch and was removed.)
        img = preprocess_input(img.astype(np.float32))

        # Add batch dimension.
        return np.expand_dims(img, axis=0)

    def predict(self, image: np.ndarray) -> Tuple[str, float]:
        """
        Predict ASL gesture from image.

        Args:
            image: Input image as numpy array (RGB)

        Returns:
            Tuple of (predicted_letter, confidence_score)

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call _load_model() first.")

        # Preprocess image
        preprocessed = self.preprocess_image(image)

        # Run inference; [0] drops the batch dimension -> vector of 29 scores.
        predictions = self.model.predict(preprocessed, verbose=0)[0]

        # Top-1 prediction. Cast away numpy scalar types so callers get
        # plain Python int/float.
        predicted_idx = int(np.argmax(predictions))
        confidence = float(predictions[predicted_idx])
        predicted_letter = self.LABELS[predicted_idx]

        return predicted_letter, confidence

    def process_frame(self, image: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[str], Optional[float]]:
        """
        Process a single frame for ASL classification.

        This method maintains compatibility with the existing ASLDetector interface.

        Args:
            image: RGB image array

        Returns:
            Tuple of (annotated_image, predicted_letter, confidence).
            On prediction failure, returns the unmodified input image and
            (None, None) — best-effort, never raises.
        """
        try:
            # Run prediction
            letter, confidence = self.predict(image)

            # Annotate a copy; the caller's image is left untouched.
            annotated_image = image.copy()

            # Only draw the overlay when reasonably confident.
            if confidence > self.MIN_DISPLAY_CONFIDENCE:
                text = f"{letter} ({confidence:.2f})"
                cv2.putText(
                    annotated_image,
                    text,
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    (0, 255, 0),
                    2
                )

            return annotated_image, letter, confidence

        except Exception as e:
            # Deliberate best-effort: a bad frame should not crash the stream.
            print(f"[ERROR] Prediction failed: {e}")
            return image, None, None

    def close(self):
        """Release resources. Models stay in cache for reuse."""
        print(f"[INFO] {self.model_name} detector closed (model remains in cache)")


def get_available_models():
    """Return the names of all configured ASL models, in definition order."""
    return list(MODEL_CONFIGS)


def get_model_info(model_name: str) -> dict:
    """Return the configuration dict for *model_name*.

    Raises:
        ValueError: If the name is not a configured model.
    """
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        raise ValueError(f"Model {model_name} not found")
    return config


def clear_model_cache():
    """Empty the global in-memory model cache to free memory.

    Mutates the existing dict in place (no rebinding), so no ``global``
    declaration is needed and existing references stay valid.
    """
    _model_cache.clear()
    print("[INFO] Model cache cleared")