AamirMalik commited on
Commit
48d0f4f
·
verified ·
1 Parent(s): f59b230

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -6,20 +6,25 @@ import torch
6
  import cv2
7
  import time
8
  import numpy as np
9
- from tensorflow.keras.models import load_model
10
 
11
- # Load the DeepASL model for live ASL alphabet classification
12
- MODEL_PATH = "asl_alphabet_model.h5"
13
- model = load_model(MODEL_PATH)
 
 
 
 
 
14
 
15
  # Function for ASL classification
16
  def classify_asl(image):
17
- image = image.resize((64, 64)) # Resize image to model input size
18
- image = np.array(image) / 255.0 # Normalize
19
- image = np.expand_dims(image, axis=0) # Add batch dimension
20
- prediction = model.predict(image)
21
  labels = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") # ASL alphabet labels
22
- return labels[np.argmax(prediction)]
23
 
24
  # Streamlit UI
25
  def main():
 
6
  import cv2
7
  import time
8
  import numpy as np
9
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
10
 
11
+ # Load the open-source ASL model from Hugging Face
12
+ @st.cache_resource
13
+ def load_asl_model():
14
+ processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
15
+ model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
16
+ return processor, model
17
+
18
+ processor, model = load_asl_model()
19
 
20
  # Function for ASL classification
21
  def classify_asl(image):
22
+ image = image.convert("RGB")
23
+ inputs = processor(images=image, return_tensors="pt")
24
+ outputs = model(**inputs)
25
+ prediction = torch.argmax(outputs.logits, dim=-1).item()
26
  labels = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") # ASL alphabet labels
27
+ return labels[prediction % len(labels)]
28
 
29
  # Streamlit UI
30
  def main():