Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoConfig | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| import logging | |
| from safetensors.torch import load_file # Import safetensors loading function | |
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Directory containing the model files ("." resolves against the current
# working directory).
model_dir = "."

# Paths to the specific files the app needs: weights, model config and
# preprocessing config.
model_path = os.path.join(model_dir, "model.safetensors")
config_path = os.path.join(model_dir, "config.json")
preprocessor_path = os.path.join(model_dir, "preprocessor_config.json")

# Fail fast with an explicit error if any required file is missing, rather
# than letting a loader fail later with a less obvious message.
for path in [model_path, config_path, preprocessor_path]:
    if not os.path.exists(path):
        logging.error(f"File not found: {path}")
        raise FileNotFoundError(f"Required file not found: {path}")
    logging.info(f"Found file: {path}")

# Load the configuration
config = AutoConfig.from_pretrained(config_path)

# Build the label list ordered by class id. predict() pairs labels[i] with
# logit index i, so the list must follow the numeric ids and not whatever
# order the entries happen to have in config.json. int() makes the sort
# numeric even if the keys are strings ("10" must come after "2").
labels = [
    name
    for _, name in sorted(config.id2label.items(), key=lambda item: int(item[0]))
]
logging.info(f"Labels: {labels}")

# Load the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(preprocessor_path)

# Load the weights from the safetensors file and build the model from the
# config plus that state dict (no checkpoint name/path is given).
state_dict = load_file(model_path)
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=state_dict,
)

# Ensure the model is in evaluation mode
model.eval()
logging.info("Model set to evaluation mode")
# Define the prediction function
def predict(image):
    """Classify an image and return a probability for every known label.

    Args:
        image: The image to classify. The Gradio input in this file is
            configured with ``type="pil"``, so a PIL image is expected.
            ``None`` is rejected with a user-facing error.

    Returns:
        A dict mapping each entry of the module-level ``labels`` list to its
        softmax probability as a Python float.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Without this guard a submission with no image crashes on
        # `image.size` with an AttributeError; gr.Error surfaces a readable
        # message in the UI instead.
        raise gr.Error("Please provide an image to classify.")

    logging.info("Starting prediction")
    logging.info(f"Input image shape: {image.size}")

    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    logging.info("Running inference")
    # Inference only: gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # logits[0] selects the first (only) item of the batch; softmax over its
    # class dimension turns the scores into probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # Prepare the output dictionary: label name -> probability
    result = {
        label: float(probabilities[index])
        for index, label in enumerate(labels)
    }
    logging.info(f"Prediction result: {result}")
    return result
# Build the Gradio interface: one PIL image in, a label widget out.
logging.info("Setting up Gradio interface")
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=6)
gradio_app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Pattern Placement Classifier",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()