DumbledoreWiz commited on
Commit
513f08f
·
verified ·
1 Parent(s): e8839a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -2,22 +2,34 @@ import torch
2
  from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
3
  import gradio as gr
4
  from PIL import Image
 
5
 
6
-
7
- print("p1")
8
  # Define the class labels as used during training
9
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
10
 
11
- # Load the configuration
12
- config = ViTConfig.from_pretrained("DumbledoreWiz/PantsShape")
 
 
 
 
 
 
 
 
 
13
 
14
- # Load the ViT model with the configuration
15
- model = ViTForImageClassification.from_pretrained("DumbledoreWiz/PantsShape", config=config)
16
- feature_extractor = ViTFeatureExtractor.from_pretrained("DumbledoreWiz/PantsShape")
17
- print("Model loaded")
18
- # Set the model to evaluation mode
19
  model.eval()
20
 
 
 
 
 
 
 
21
  # Define the prediction function
22
  def predict(image):
23
  # Preprocess the image
 
2
  from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig
3
  import gradio as gr
4
  from PIL import Image
5
+ import os
6
 
 
 
7
  # Define the class labels as used during training
8
  labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
9
 
10
+ # Define the path to the uploaded model file
11
+ model_path = "final_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_2024-08-14.pth"
12
+
13
+ # Check if config.json exists, otherwise use default config
14
+ if os.path.exists("config.json"):
15
+ config = ViTConfig.from_pretrained(".")
16
+ else:
17
+ config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
18
+ config.num_labels = len(labels)
19
+ config.id2label = {str(i): label for i, label in enumerate(labels)}
20
+ config.label2id = {label: str(i) for i, label in enumerate(labels)}
21
 
22
+ # Load the model
23
+ model = ViTForImageClassification(config)
24
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
 
25
  model.eval()
26
 
27
+ # Load or create feature extractor
28
+ if os.path.exists("preprocessor_config.json"):
29
+ feature_extractor = ViTFeatureExtractor.from_pretrained(".")
30
+ else:
31
+ feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
32
+
33
  # Define the prediction function
34
  def predict(image):
35
  # Preprocess the image