|
|
import torch |
|
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
from ultralytics import YOLO |
|
|
import gdown |
|
|
import os |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
# Put gdown's cache in a "cache" folder under the current working directory,
# creating it up front so it exists before any download starts.
gdown_cache_dir = os.path.join(os.getcwd(), "cache")
os.makedirs(gdown_cache_dir, exist_ok=True)
# NOTE(review): assumes the installed gdown version reads GDOWN_CACHE from the
# environment — confirm, otherwise this variable has no effect.
os.environ.update(GDOWN_CACHE=gdown_cache_dir)
|
|
|
|
|
def download_model_from_drive(file_id, destination_path):
    """Download a file from Google Drive to ``destination_path`` using gdown.

    Args:
        file_id: Google Drive file id, substituted into the
            ``https://drive.google.com/uc?id=...`` download URL.
        destination_path: Local path to write the file to. Its parent
            directory is created first if it does not exist.

    Returns:
        The output path reported by ``gdown.download``.

    Raises:
        RuntimeError: If ``gdown.download`` returns ``None`` (how gdown may
            report a failed download), so callers do not go on to read a
            file that was never written.
    """
    url = f"https://drive.google.com/uc?id={file_id}"

    # dirname() is "" for a bare filename; os.makedirs("") would raise, so
    # only create a directory when the path actually has one.
    directory = os.path.dirname(destination_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    output = gdown.download(url, destination_path, quiet=False)
    if output is None:
        # Without this check the failure only surfaces later, as an opaque
        # error when the (missing) file is loaded.
        raise RuntimeError(
            f"gdown failed to download file id {file_id!r} "
            f"to {destination_path!r}"
        )
    return output
|
|
|
|
|
def load_models(
    device='cpu',
    *,
    model_file_path="model.safetensors",
    drive_file_id="1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k",
    yolo_weights_path="best.pt",
    caption_model_name="microsoft/Florence-2-base",
):
    """Load the YOLO detector and the fine-tuned caption generation model.

    If ``model_file_path`` does not exist locally, the caption weights are
    first downloaded from Google Drive.

    Args:
        device: Device passed to ``.to()`` for both models (e.g. ``"cpu"``
            or ``"cuda"``).
        model_file_path: Local path of the safetensors file holding the
            caption model's state dict.
        drive_file_id: Google Drive file id to download the caption weights
            from when ``model_file_path`` is missing.
        yolo_weights_path: Weights file passed to ``ultralytics.YOLO``.
        caption_model_name: Hugging Face Hub id used for both the processor
            and the base caption model architecture.

    Returns:
        A dict with keys ``'yolo_model'``, ``'processor'`` and
        ``'caption_model'``.
    """
    if not os.path.exists(model_file_path):
        print(f"Downloading model to {model_file_path}...")
        download_model_from_drive(drive_file_id, model_file_path)

    print("Loading YOLO model...")
    yolo_model = YOLO(yolo_weights_path).to(device)

    # NOTE: trust_remote_code=True executes Python code shipped with the Hub
    # repository; only use it with repositories you trust.
    print("Loading processor for the caption model...")
    processor = AutoProcessor.from_pretrained(
        caption_model_name,
        trust_remote_code=True
    )

    print("Loading caption generation model...")
    model_state_dict = load_file(model_file_path)
    # from_pretrained builds the architecture (and loads the base weights);
    # the local fine-tuned state dict then overwrites those parameters.
    caption_model = AutoModelForCausalLM.from_pretrained(
        caption_model_name,
        trust_remote_code=True
    )
    caption_model.load_state_dict(model_state_dict)
    # load_state_dict copies values into the model's own parameters, so the
    # loaded dict is no longer needed; free it before the device move to
    # lower peak host memory.
    del model_state_dict
    caption_model.to(device)

    print("Models loaded successfully!")
    return {
        'yolo_model': yolo_model,
        'processor': processor,
        'caption_model': caption_model
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    models = load_models(device=device)
    print("All models are ready to use!")
|
|
|