Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| import logging | |
# Configure root logging at import time: INFO level and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Class names, in the same order as used during training. predict() pairs
# labels[i] with the model's i-th output, so this order must not change.
labels = [
    'Leggings',
    'Jogger',
    'Palazzo',
    'Cargo',
    'Dresspants',
    'Chinos',
]
logging.info(f"Labels: {labels}")
# Fine-tuned checkpoint expected next to the app. This is a relative path, so it
# is resolved against the current working directory. The filename encodes the
# class names, a score (presumably validation accuracy; unconfirmed) and a date.
model_path = "best_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_93.90243902439025_2024-08-26.pth"
logging.info(f"Looking for model file: {model_path}")

# Fail fast with an explicit error before any model construction happens.
if not os.path.exists(model_path):
    logging.error(f"Model file not found: {model_path}")
    raise FileNotFoundError(f"Model file not found: {model_path}")
logging.info(f"Model file found: {model_path}")
# Hugging Face checkpoint the classifier was fine-tuned from. Both the model
# config and the image preprocessor are loaded from it, so it is named once
# here to keep the two in agreement.
base_checkpoint = "google/vit-base-patch16-224-in21k"

# Index <-> label mappings, consistent with the order of `labels`.
# Indices are stored as strings.
id2label = {str(i): label for i, label in enumerate(labels)}
label2id = {label: str(i) for i, label in enumerate(labels)}

# Start from the base ViT configuration and size the classification head
# for our label set.
config = ViTConfig.from_pretrained(base_checkpoint)
config.num_labels = len(labels)
config.id2label = id2label
config.label2id = label2id

# Build the architecture with fresh weights. The fine-tuned weights are loaded
# from the local checkpoint below.
model = ViTForImageClassification(config)
try:
    # The loaded object goes straight into load_state_dict, so it is expected
    # to be a mapping of tensors. weights_only=True restricts unpickling to
    # such data, so a tampered .pth cannot run arbitrary code on load.
    # It also silences the FutureWarning from torch 2.4-2.5.
    # Requires torch >= 1.13.
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    logging.info("Fine-tuned model loaded successfully")
except Exception as e:
    # logging.exception also records the traceback; the error is re-raised so
    # the app does not start with a half-initialized model.
    logging.exception(f"Error loading model: {str(e)}")
    raise

# Inference only: switch off training-time behaviour such as dropout.
model.eval()
logging.info("Model set to evaluation mode")

# Preprocessor matching the base checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(base_checkpoint)
logging.info("Feature extractor loaded")
def predict(image):
    """Classify a pants image and return a probability for every label.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input. It may be
            ``None`` if nothing was uploaded.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability
        (a Python float), the format consumed by ``gr.Label``.

    Raises:
        gr.Error: if ``image`` is ``None``. Gradio shows the message to the user.
    """
    logging.info("Starting prediction")
    if image is None:
        # Without this guard, image.size below fails with an opaque AttributeError.
        raise gr.Error("Please upload an image before submitting.")
    # PIL's .size is (width, height), not an array shape.
    logging.info(f"Input image size (width, height): {image.size}")

    # Convert non-RGB inputs (e.g. RGBA PNGs, grayscale, palette) to RGB. The
    # ViT image preprocessor is designed for 3-channel images.
    if image.mode != "RGB":
        image = image.convert("RGB")

    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # One image per call: take batch row 0 and softmax over the class axis.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # labels[i] corresponds to output i (the config was built from `labels`).
    result = dict(zip(labels, probabilities.tolist()))
    logging.info(f"Prediction result: {result}")
    return result
# Gradio UI: a single PIL image in, a label/probability panel out.
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    # Show every class. Derived from `labels` so it stays correct if the
    # label set changes (previously hard-coded to 6).
    outputs=gr.Label(num_top_classes=len(labels)),
    title="Pants Shape Classifier",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()