# Knights-Lab-Assessment / artifact_seeder.py
# Commit 8d97257 (nasr7322): models seeder to be removed later
import os
import torch
import torchvision.transforms as T
from torchvision.models import resnet50
from torchvision.models.segmentation import deeplabv3_resnet50
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
from ultralytics import YOLO
import logging
# Per-module logger; its records propagate up to the root logger, which
# basicConfig below sets up to emit INFO-level messages and above.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)
def _save_torchvision_weights(build_model, weights, path, label):
    """Build a torchvision model with pretrained ``weights`` and save its state_dict.

    Args:
        build_model: torchvision model constructor that accepts a ``weights=``
            keyword (e.g. ``resnet50``).
        weights: Weights identifier forwarded to ``build_model``.
        path: Destination file passed to ``torch.save``.
        label: Human-readable model name used in the log messages.

    Returns:
        ``path``, so callers can record where the weights were written.
    """
    logger.info(f"Downloading {label} model...")
    model = build_model(weights=weights)
    # Only parameters/buffers are persisted; train/eval mode does not change
    # state_dict(), so no .eval() call is needed before saving.
    torch.save(model.state_dict(), path)
    logger.info(f"{label} saved to: {path}")
    # `model` goes out of scope on return, so each network can be freed
    # before the next one is downloaded.
    return path


def download_models(artifacts_dir=os.path.join("prediction", "artifacts")):
    """Download pretrained models and save them into ``artifacts_dir``.

    Files written:

    * ``resnet50_imagenet.pth``  - ResNet50 state_dict (``IMAGENET1K_V1`` weights)
    * ``deeplabv3_resnet50.pth`` - DeepLabV3-ResNet50 state_dict (``DEFAULT`` weights)
    * ``yolov5s.pt``             - YOLO checkpoint written by ultralytics ``save()``

    Args:
        artifacts_dir: Destination directory, created if missing. Defaults to
            ``prediction/artifacts`` relative to the current working directory.

    Returns:
        dict mapping ``"resnet50"``, ``"deeplabv3"`` and ``"yolo"`` to the
        paths of the saved files.

    Raises:
        Exception: any error raised while downloading or saving is logged and
            then re-raised unchanged.
    """
    os.makedirs(artifacts_dir, exist_ok=True)
    logger.info(f"Artifacts directory: {artifacts_dir}")

    saved = {}
    try:
        saved["resnet50"] = _save_torchvision_weights(
            resnet50,
            "IMAGENET1K_V1",
            os.path.join(artifacts_dir, "resnet50_imagenet.pth"),
            "ResNet50",
        )
        saved["deeplabv3"] = _save_torchvision_weights(
            deeplabv3_resnet50,
            "DEFAULT",
            os.path.join(artifacts_dir, "deeplabv3_resnet50.pth"),
            "DeepLabV3",
        )

        # YOLO checkpoints are fetched by ultralytics itself and written with
        # the wrapper's own save() instead of torch.save(state_dict).
        # NOTE(review): ultralytics may resolve/download "yolov5s.pt" relative
        # to the working directory (and may substitute a newer variant) —
        # confirm this leaves no stray file outside artifacts_dir.
        logger.info("Downloading YOLO model...")
        yolo_model = YOLO("yolov5s.pt")
        yolo_path = os.path.join(artifacts_dir, "yolov5s.pt")
        yolo_model.save(yolo_path)
        logger.info(f"YOLO saved to: {yolo_path}")
        saved["yolo"] = yolo_path
    except Exception as e:
        # Top-level boundary: record the failure, then propagate it unchanged.
        logger.error(f"Error downloading models: {e}")
        raise

    logger.info("All models downloaded successfully!")
    return saved
if __name__ == "__main__":
    # Executed only when this file is run as a script (not on import):
    # seed the artifacts directory using download_models' defaults.
    download_models()