# models.py — model download and loading utilities (YOLO detector + Florence-2 captioner)
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO
import gdown
import os
from safetensors.torch import load_file # Safetensors loading method
# Set a custom cache directory for gdown so cached downloads land inside the
# working directory instead of the user's home (useful in sandboxed deployments
# such as HF Spaces where $HOME may be read-only).
gdown_cache_dir = os.path.join(os.getcwd(), "cache")
os.makedirs(gdown_cache_dir, exist_ok=True)
os.environ["GDOWN_CACHE"] = gdown_cache_dir # Explicitly set GDOWN_CACHE
def download_model_from_drive(file_id, destination_path):
    """Download a file from Google Drive using gdown.

    Args:
        file_id: The Google Drive file ID (the ``id=`` query parameter of a
            share link).
        destination_path: Local path the file is written to. Parent
            directories are created as needed.

    Returns:
        The path of the downloaded file on success.

    Raises:
        RuntimeError: If gdown could not complete the download (it signals
            failure by returning ``None`` rather than raising).
    """
    # Construct the direct-download URL gdown expects.
    url = f"https://drive.google.com/uc?id={file_id}"

    # Ensure the destination directory exists before gdown writes into it.
    directory = os.path.dirname(destination_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # gdown.download returns the output path on success and None on failure
    # (e.g. permission/quota errors); fail loudly instead of letting a later
    # load step crash on a missing file.
    result = gdown.download(url, destination_path, quiet=False)
    if result is None:
        raise RuntimeError(
            f"Failed to download Google Drive file {file_id} to {destination_path}"
        )
    return result
def load_models(
    device='cpu',
    model_file_path="model.safetensors",
    yolo_weights_path="best.pt",
    base_model_name="microsoft/Florence-2-base",
    drive_file_id="1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k",
):
    """Load the YOLO detector and the Florence-2 caption model.

    Downloads the fine-tuned caption weights from Google Drive on first use,
    then loads them into the base Florence-2 model.

    Args:
        device: Target device string (e.g. ``'cpu'`` or ``'cuda'``).
        model_file_path: Local path of the fine-tuned ``.safetensors`` weights;
            downloaded from Drive if missing.
        yolo_weights_path: Path to the YOLO ``.pt`` checkpoint
            (assumed to exist locally — TODO confirm it ships with the repo).
        base_model_name: Hugging Face repo id of the base caption model.
        drive_file_id: Google Drive file ID for the fine-tuned weights.

    Returns:
        dict with keys ``'yolo_model'``, ``'processor'``, ``'caption_model'``.
    """
    # Download the fine-tuned weights only if they are not already on disk.
    if not os.path.exists(model_file_path):
        print(f"Downloading model to {model_file_path}...")
        download_model_from_drive(drive_file_id, model_file_path)

    # Load the YOLO detector and move it to the requested device.
    print("Loading YOLO model...")
    yolo_model = YOLO(yolo_weights_path).to(device)

    # The processor (tokenizer + image preprocessor) comes from the base repo;
    # Florence-2 requires trust_remote_code because it ships custom code.
    print("Loading processor for the caption model...")
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True
    )

    # Instantiate the base architecture, then overwrite its weights with the
    # fine-tuned tensors from the .safetensors file (strict matching by
    # default, so any key mismatch raises immediately).
    print("Loading caption generation model...")
    model_state_dict = load_file(model_file_path)
    caption_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        trust_remote_code=True
    )
    caption_model.load_state_dict(model_state_dict)
    caption_model.to(device)

    print("Models loaded successfully!")
    return {
        'yolo_model': yolo_model,
        'processor': processor,
        'caption_model': caption_model
    }
# Usage example:
if __name__ == "__main__":
    # Prefer the GPU when one is available, otherwise fall back to CPU.
    selected_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded = load_models(device=selected_device)
    print("All models are ready to use!")