DumbledoreWiz commited on
Commit
4b0d68d
·
verified ·
1 Parent(s): 672d7ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -34
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
3
  import gradio as gr
4
  from PIL import Image
5
  import os
@@ -8,49 +8,34 @@ import logging
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
 
11
- # Define the class labels in the correct order as used during training
12
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
13
  logging.info(f"Labels: {labels}")
14
 
15
- # Define the path to the uploaded model file
16
- model_path = "best_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_93.90243902439025_2024-08-26.pth"
17
- logging.info(f"Looking for model file: {model_path}")
 
18
 
19
- if os.path.exists(model_path):
20
- logging.info(f"Model file found: {model_path}")
21
- else:
22
- logging.error(f"Model file not found: {model_path}")
23
- raise FileNotFoundError(f"Model file not found: {model_path}")
 
 
24
 
25
- # Create label mappings consistent with training
26
- id2label = {str(i): label for i, label in enumerate(labels)}
27
- label2id = {label: str(i) for i, label in enumerate(labels)}
28
-
29
- # Create a configuration for the model
30
- config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
31
- config.num_labels = len(labels)
32
- config.id2label = id2label
33
- config.label2id = label2id
34
 
35
  # Initialize the model with the configuration
36
- model = ViTForImageClassification(config)
37
-
38
- try:
39
- # Load the state dict of the fine-tuned model
40
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
41
- model.load_state_dict(state_dict)
42
- logging.info("Fine-tuned model loaded successfully")
43
- except Exception as e:
44
- logging.error(f"Error loading model: {str(e)}")
45
- raise
46
 
 
47
  model.eval()
48
  logging.info("Model set to evaluation mode")
49
 
50
- # Load feature extractor
51
- feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
52
- logging.info("Feature extractor loaded")
53
-
54
  # Define the prediction function
55
  def predict(image):
56
  logging.info("Starting prediction")
@@ -88,4 +73,4 @@ gradio_app = gr.Interface(
88
  # Launch the app
89
  if __name__ == "__main__":
90
  logging.info("Launching the app")
91
- gradio_app.launch()
 
1
  import torch
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
  import gradio as gr
4
  from PIL import Image
5
  import os
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
 
11
+ # Define the labels in the correct order as used during training
12
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
13
  logging.info(f"Labels: {labels}")
14
 
15
+ # Define paths to the model files (all in the same directory as app.py)
16
+ model_path = "model.safetensors"
17
+ config_path = "config.json"
18
+ preprocessor_path = "preprocessor_config.json"
19
 
20
+ # Check if all required files exist
21
+ for path in [model_path, config_path, preprocessor_path]:
22
+ if not os.path.exists(path):
23
+ logging.error(f"File not found: {path}")
24
+ raise FileNotFoundError(f"Required file not found: {path}")
25
+ else:
26
+ logging.info(f"Found file: {path}")
27
 
28
+ # Load the configuration and feature extractor
29
+ config = ViTForImageClassification.from_pretrained(".", config=config_path)
30
+ feature_extractor = ViTFeatureExtractor.from_pretrained(".")
 
 
 
 
 
 
31
 
32
  # Initialize the model with the configuration
33
+ model = ViTForImageClassification.from_pretrained(".", config=config)
 
 
 
 
 
 
 
 
 
34
 
35
+ # Ensure the model is in evaluation mode
36
  model.eval()
37
  logging.info("Model set to evaluation mode")
38
 
 
 
 
 
39
  # Define the prediction function
40
  def predict(image):
41
  logging.info("Starting prediction")
 
73
  # Launch the app
74
  if __name__ == "__main__":
75
  logging.info("Launching the app")
76
+ gradio_app.launch()