DumbledoreWiz commited on
Commit
3713f52
·
verified ·
1 Parent(s): 85dcc91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  import gradio as gr
4
  from PIL import Image
5
  import os
@@ -12,8 +12,10 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
12
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
13
  logging.info(f"Labels: {labels}")
14
 
15
- # Define paths to the model files (all in the same directory as app.py)
16
  model_dir = "." # Use current directory
 
 
17
  model_path = os.path.join(model_dir, "model.safetensors")
18
  config_path = os.path.join(model_dir, "config.json")
19
  preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")
@@ -26,11 +28,18 @@ for path in [model_path, config_path, preprocessor_path]:
26
  else:
27
  logging.info(f"Found file: {path}")
28
 
29
- # Load the model and feature extractor using the local directory
30
- model_id = "google/vit-base-patch16-224"
 
 
 
31
 
32
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
33
- model = ViTForImageClassification.from_pretrained(model_path)
 
 
 
 
34
 
35
  # Ensure the model is in evaluation mode
36
  model.eval()
 
1
  import torch
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig
3
  import gradio as gr
4
  from PIL import Image
5
  import os
 
12
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
13
  logging.info(f"Labels: {labels}")
14
 
15
+ # Define the directory containing the model files
16
  model_dir = "." # Use current directory
17
+
18
+ # Define paths to the specific model files
19
  model_path = os.path.join(model_dir, "model.safetensors")
20
  config_path = os.path.join(model_dir, "config.json")
21
  preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")
 
28
  else:
29
  logging.info(f"Found file: {path}")
30
 
31
+ # Load the configuration
32
+ config = AutoConfig.from_pretrained(config_path)
33
+
34
+ # Load the feature extractor
35
+ feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)
36
 
37
+ # Load the model using the specific paths
38
+ model = ViTForImageClassification.from_pretrained(
39
+ pretrained_model_name_or_path=None,
40
+ config=config,
41
+ state_dict=torch.load(model_path, map_location=torch.device('cpu'))
42
+ )
43
 
44
  # Ensure the model is in evaluation mode
45
  model.eval()