DumbledoreWiz commited on
Commit
1b79fe8
·
verified ·
1 Parent(s): 3713f52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  from PIL import Image
5
  import os
6
  import logging
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -34,11 +35,12 @@ config = AutoConfig.from_pretrained(config_path)
34
  # Load the feature extractor
35
  feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)
36
 
37
- # Load the model using the specific paths
 
38
  model = ViTForImageClassification.from_pretrained(
39
  pretrained_model_name_or_path=None,
40
  config=config,
41
- state_dict=torch.load(model_path, map_location=torch.device('cpu'))
42
  )
43
 
44
  # Ensure the model is in evaluation mode
 
4
  from PIL import Image
5
  import os
6
  import logging
7
+ from safetensors.torch import load_file # Import safetensors loading function
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
35
  # Load the feature extractor
36
  feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)
37
 
38
+ # Load the model using the safetensors file
39
+ state_dict = load_file(model_path) # Use safetensors to load the model weights
40
  model = ViTForImageClassification.from_pretrained(
41
  pretrained_model_name_or_path=None,
42
  config=config,
43
+ state_dict=state_dict
44
  )
45
 
46
  # Ensure the model is in evaluation mode