# weights/models.py
# Author: Sanket17
# Last change: "Update models.py" (commit 426b1bc, verified)
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO
import gdown
import os
from safetensors.torch import load_file # Safetensors loading method
def download_model_from_drive(file_id, destination_path):
    """Download a file from Google Drive using gdown.

    Args:
        file_id: Google Drive file ID, interpolated into a
            ``https://drive.google.com/uc?id=...`` URL.
        destination_path: Local path the file is written to. A missing
            parent directory is created first.

    Returns:
        The output path reported by ``gdown.download``.

    Raises:
        RuntimeError: If ``gdown.download`` returns ``None`` (failure).
    """
    url = f"https://drive.google.com/uc?id={file_id}"

    # Ensure the destination's directory exists before downloading.
    parent_dir = os.path.dirname(destination_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    output = gdown.download(url, destination_path, quiet=False)
    if output is None:
        # NOTE(review): some gdown releases signal failure by returning None
        # instead of raising; without this check the caller only discovers the
        # problem later, when the expected file is missing. Confirm against
        # the pinned gdown version.
        raise RuntimeError(
            f"Failed to download Google Drive file {file_id!r} "
            f"to {destination_path!r}"
        )
    return output
def load_models(device='cpu'):
    """Initialize and load all required models.

    On first use, downloads the caption-model weights (``model.safetensors``)
    from Google Drive. It then builds:

    * ``yolo_model``: an Ultralytics ``YOLO`` model from ``best.pt``, moved
      to ``device``.
    * ``processor``: the ``microsoft/Florence-2-base`` processor.
    * ``caption_model``: a ``microsoft/Florence-2-base`` causal-LM whose
      weights come from ``model.safetensors``, moved to ``device``.

    Args:
        device: Torch device the models are moved to (default ``'cpu'``).

    Returns:
        dict with keys ``'yolo_model'``, ``'processor'`` and
        ``'caption_model'``.

    Note:
        Sets the process-wide torch default dtype to float32. ``best.pt`` and
        ``model.safetensors`` are resolved relative to the current working
        directory.
    """
    # Process-wide side effect: affects every tensor created afterwards.
    torch.set_default_dtype(torch.float32)

    # Download the caption weights from Google Drive if not already present.
    model_file_path = 'model.safetensors'
    if not os.path.exists(model_file_path):
        file_id = "1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k"
        download_model_from_drive(file_id, model_file_path)

    # Load the YOLO model.
    yolo_model = YOLO('best.pt').to(device)

    # Load the processor that prepares inputs for the caption model.
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base",
        trust_remote_code=True,
    )

    # safetensors.torch.load_file(filename, device=...) returns a plain
    # {name: tensor} dict. It takes no ``framework`` argument (that belongs to
    # safetensors.safe_open) and does not build a model by itself. Load the
    # tensors on CPU, hand them to the Florence-2 architecture via
    # ``from_pretrained(state_dict=...)``, then move the model to ``device``.
    state_dict = load_file(model_file_path, device='cpu')
    caption_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base",
        trust_remote_code=True,
        state_dict=state_dict,
    ).to(device)

    return {
        'yolo_model': yolo_model,
        'processor': processor,
        'caption_model': caption_model,
    }