Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torchvision.transforms as T | |
| from torchvision.models import resnet50 | |
| from torchvision.models.segmentation import deeplabv3_resnet50 | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import logging | |
# Logging setup: obtain this module's named logger, and give the root logger a
# basic handler at INFO level so the progress messages below are emitted
# (logging.basicConfig does nothing if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def _download_torchvision_model(label, builder, weights, path):
    """Build a torchvision model with pretrained ``weights`` and save its state_dict.

    Kept as a separate function so each model goes out of scope as soon as it
    has been saved, instead of every model staying referenced until the caller
    returns.

    Args:
        label: Human-readable model name used in the log messages.
        builder: torchvision model constructor (e.g. ``resnet50``), called as
            ``builder(weights=weights)``.
        weights: Weights identifier forwarded to ``builder``.
        path: Destination file passed to ``torch.save``.
    """
    logger.info(f"Downloading {label} model...")
    model = builder(weights=weights)
    # Only parameters and buffers go into a state_dict; train/eval mode is not
    # part of it, so no .eval() call is needed before saving.
    torch.save(model.state_dict(), path)
    logger.info(f"{label} saved to: {path}")


def download_models(artifacts_dir=None):
    """Download pretrained models and save them to the artifacts directory.

    Saves three models:

    * torchvision ResNet50 (``IMAGENET1K_V1`` weights), state_dict only,
      as ``resnet50_imagenet.pth``
    * torchvision DeepLabV3-ResNet50 (``DEFAULT`` weights), state_dict only,
      as ``deeplabv3_resnet50.pth``
    * Ultralytics ``YOLO("yolov5s.pt")``, via ``YOLO.save``, as ``yolov5s.pt``

    Args:
        artifacts_dir: Directory to write the files into; created if missing.
            Defaults to ``prediction/artifacts`` relative to the current
            working directory.

    Returns:
        dict mapping ``"resnet50"``, ``"deeplabv3"`` and ``"yolo"`` to the
        paths of the saved files.

    Raises:
        Exception: any error raised while downloading or saving is logged
            together with its traceback and then re-raised unchanged.
    """
    if artifacts_dir is None:
        artifacts_dir = os.path.join("prediction", "artifacts")
    os.makedirs(artifacts_dir, exist_ok=True)
    logger.info(f"Artifacts directory: {artifacts_dir}")

    paths = {
        "resnet50": os.path.join(artifacts_dir, "resnet50_imagenet.pth"),
        "deeplabv3": os.path.join(artifacts_dir, "deeplabv3_resnet50.pth"),
        "yolo": os.path.join(artifacts_dir, "yolov5s.pt"),
    }

    try:
        _download_torchvision_model(
            "ResNet50", resnet50, "IMAGENET1K_V1", paths["resnet50"]
        )
        _download_torchvision_model(
            "DeepLabV3", deeplabv3_resnet50, "DEFAULT", paths["deeplabv3"]
        )

        logger.info("Downloading YOLO model...")
        # NOTE(review): recent ultralytics releases appear to remap the legacy
        # name "yolov5s.pt" to "yolov5su.pt" (downloading it into the current
        # working directory), and the same remap may apply when this artifact
        # is later loaded via YOLO(".../yolov5s.pt"). Confirm against the
        # installed ultralytics version.
        yolo_model = YOLO("yolov5s.pt")
        # Ultralytics models are persisted with their own save() rather than
        # torch.save(state_dict); presumably a full ultralytics checkpoint.
        yolo_model.save(paths["yolo"])
        logger.info(f"YOLO saved to: {paths['yolo']}")

        logger.info("All models downloaded successfully!")
    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception(f"Error downloading models: {e}")
        raise

    return paths
if "__main__" == __name__:
    # Run directly as a script (not imported): fetch and store every model.
    download_models()