AamirMalik commited on
Commit
8801e16
·
verified ·
1 Parent(s): db8d5e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -11,11 +11,11 @@ from transformers import AutoImageProcessor, AutoModelForImageClassification
11
  # Set page config as the first command
12
  st.set_page_config(page_title="Sign Language Translator", layout="wide")
13
 
14
- # Load the open-source ASL model from Hugging Face
15
  @st.cache_resource
16
  def load_asl_model():
17
- processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
18
- model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
19
  return processor, model
20
 
21
  processor, model = load_asl_model()
@@ -27,7 +27,7 @@ def classify_asl(image):
27
  outputs = model(**inputs)
28
  prediction = torch.argmax(outputs.logits, dim=-1).item()
29
  labels = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") # ASL alphabet labels
30
- return labels[prediction % len(labels)]
31
 
32
  # Streamlit UI
33
  def main():
 
11
  # Set page config as the first command
12
  st.set_page_config(page_title="Sign Language Translator", layout="wide")
13
 
14
+ # Load the ASL alphabet model from Hugging Face
15
  @st.cache_resource
16
  def load_asl_model():
17
+ processor = AutoImageProcessor.from_pretrained("Roboflow/ASL-Alphabet-Classifier")
18
+ model = AutoModelForImageClassification.from_pretrained("Roboflow/ASL-Alphabet-Classifier")
19
  return processor, model
20
 
21
  processor, model = load_asl_model()
 
27
  outputs = model(**inputs)
28
  prediction = torch.argmax(outputs.logits, dim=-1).item()
29
  labels = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") # ASL alphabet labels
30
+ return labels[prediction] if prediction < len(labels) else "Unknown"
31
 
32
  # Streamlit UI
33
  def main():