DumbledoreWiz committed on
Commit
7132357
·
verified ·
1 Parent(s): 513f08f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -13
app.py CHANGED
@@ -3,48 +3,93 @@ from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConf
3
  import gradio as gr
4
  from PIL import Image
5
  import os
 
 
 
 
6
 
7
  # Define the class labels as used during training
8
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
 
9
 
10
  # Define the path to the uploaded model file
11
  model_path = "final_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_2024-08-14.pth"
 
12
 
13
- # Check if config.json exists, otherwise use default config
14
- if os.path.exists("config.json"):
15
- config = ViTConfig.from_pretrained(".")
16
  else:
17
- config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
18
- config.num_labels = len(labels)
19
- config.id2label = {str(i): label for i, label in enumerate(labels)}
20
- config.label2id = {label: str(i) for i, label in enumerate(labels)}
 
 
 
 
 
21
 
22
- # Load the model
 
23
  model = ViTForImageClassification(config)
24
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  model.eval()
 
26
 
27
  # Load or create feature extractor
28
- if os.path.exists("preprocessor_config.json"):
29
- feature_extractor = ViTFeatureExtractor.from_pretrained(".")
30
- else:
31
- feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
32
 
33
  # Define the prediction function
34
  def predict(image):
 
 
 
35
  # Preprocess the image
 
36
  inputs = feature_extractor(images=image, return_tensors="pt")
 
37
 
 
38
  with torch.no_grad():
39
  outputs = model(**inputs)
40
  logits = outputs.logits
41
  probabilities = torch.nn.functional.softmax(logits[0], dim=0)
42
 
 
 
 
43
  # Prepare the output dictionary
44
  result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
 
 
45
  return result
46
 
47
  # Set up the Gradio Interface
 
48
  gradio_app = gr.Interface(
49
  fn=predict,
50
  inputs=gr.Image(type="pil"),
@@ -54,4 +99,5 @@ gradio_app = gr.Interface(
54
 
55
  # Launch the app
56
  if __name__ == "__main__":
 
57
  gradio_app.launch()
 
3
  import gradio as gr
4
  from PIL import Image
5
  import os
6
+ import logging
7
+
8
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the class labels as used during training.
# The order matters: index i of the model's output is reported as labels[i].
labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
logging.info(f"Labels: {labels}")

# Define the path to the uploaded model file
model_path = "final_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_2024-08-14.pth"
logging.info(f"Looking for model file: {model_path}")

# Fail fast with a clear message if the checkpoint was not uploaded.
if not os.path.exists(model_path):
    logging.error(f"Model file not found: {model_path}")
    raise FileNotFoundError(f"Model file not found: {model_path}")
logging.info(f"Model file found: {model_path}")

# Create a custom configuration: start from the base ViT config and size the
# classification head for our labels.
config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
config.num_labels = len(labels)
# Transformers documents id2label as Dict[int, str] and label2id as
# Dict[str, int]; use integer ids rather than their string form.
config.id2label = {i: label for i, label in enumerate(labels)}
config.label2id = {label: i for i, label in enumerate(labels)}
logging.info(f"Custom config created with {len(labels)} labels")

# Load the model with the custom configuration
logging.info("Loading the model with custom configuration")
model = ViTForImageClassification(config)

try:
    # Load the state dict.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust. If the installed torch supports it and the file holds a plain
    # state dict, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Check if the state dict keys match the model's keys
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())

    if model_keys != loaded_keys:
        logging.warning("Mismatch in state dict keys. Attempting to adjust...")
        # Strip a leading 'module.' (added when a model is saved while wrapped
        # in DataParallel). Only the prefix is removed, so a key that merely
        # contains 'module.' further along is left intact.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    # load_state_dict is strict by default, so any remaining mismatch raises
    # here instead of silently leaving weights uninitialised.
    model.load_state_dict(state_dict)
    logging.info("Model loaded successfully")
except Exception as e:
    logging.error(f"Error loading model: {str(e)}")
    raise

model.eval()
logging.info("Model set to evaluation mode")

# Load or create feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
logging.info("Feature extractor loaded")

logging.info("Model and feature extractor loaded successfully")
65
 
66
  # Define the prediction function
67
  def predict(image):
68
+ logging.info("Starting prediction")
69
+ logging.info(f"Input image shape: {image.size}")
70
+
71
  # Preprocess the image
72
+ logging.info("Preprocessing image")
73
  inputs = feature_extractor(images=image, return_tensors="pt")
74
+ logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")
75
 
76
+ logging.info("Running inference")
77
  with torch.no_grad():
78
  outputs = model(**inputs)
79
  logits = outputs.logits
80
  probabilities = torch.nn.functional.softmax(logits[0], dim=0)
81
 
82
+ logging.info(f"Raw logits: {logits}")
83
+ logging.info(f"Probabilities: {probabilities}")
84
+
85
  # Prepare the output dictionary
86
  result = {labels[i]: float(probabilities[i]) for i in range(len(labels))}
87
+ logging.info(f"Prediction result: {result}")
88
+
89
  return result
90
 
91
  # Set up the Gradio Interface
92
+ logging.info("Setting up Gradio interface")
93
  gradio_app = gr.Interface(
94
  fn=predict,
95
  inputs=gr.Image(type="pil"),
 
99
 
100
  # Launch the app
101
  if __name__ == "__main__":
102
+ logging.info("Launching the app")
103
  gradio_app.launch()