File size: 1,494 Bytes
19604c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os

# Raise the Hugging Face Hub download timeout (seconds) for large checkpoints.
# This MUST happen before transformers is imported: transformers imports
# huggingface_hub, which reads HF_HUB_DOWNLOAD_TIMEOUT into a module-level
# constant once, at import time. Assigning the variable after the imports
# below (as this file previously did) silently has no effect.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "120"

import torch  # noqa: E402  (imported after the env var on purpose)
from transformers import (  # noqa: E402  (imported after the env var on purpose)
    DetrImageProcessor,
    DetrForObjectDetection,
    ViTImageProcessor,
    ViTForImageClassification,
    CLIPProcessor,
    CLIPModel,
    AutoTokenizer,
    AutoModel,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    BartForConditionalGeneration,
)

def preload(*, models=None):
    """Fetch every configured processor/model pair ahead of time.

    For each entry, ``proc_cls.from_pretrained(model_id)`` and then
    ``model_cls.from_pretrained(model_id)`` are called. The returned objects
    are not kept; the calls are made for their download/caching side effect.
    Loading is best-effort. A failing entry is reported and recorded, and the
    remaining entries are still attempted.

    Args:
        models: Optional mapping ``name -> (model_id, model_cls, proc_cls)``,
            where both classes expose ``from_pretrained(model_id)``. When
            omitted, the built-in table defined below is used.

    Returns:
        list[str]: Names of the entries that failed to load, in iteration
        order. The list is empty when everything loaded.
    """
    print("🚀 Starting model pre-loading...")

    if models is None:
        models = {
            "detection": ("facebook/detr-resnet-50", DetrForObjectDetection, DetrImageProcessor),
            "reid": ("google/vit-base-patch16-224", ViTForImageClassification, ViTImageProcessor),
            "clip": ("openai/clip-vit-base-patch32", CLIPModel, CLIPProcessor),
            "search": ("sentence-transformers/all-MiniLM-L6-v2", AutoModel, AutoTokenizer),
            "qa": ("deepset/roberta-base-squad2", AutoModelForQuestionAnswering, AutoTokenizer),
            "report": ("google/flan-t5-base", AutoModelForSeq2SeqLM, AutoTokenizer),
            "summarizer": ("facebook/bart-large-cnn", BartForConditionalGeneration, AutoTokenizer),
        }

    failed = []
    for name, (model_id, model_cls, proc_cls) in models.items():
        print(f"📦 Pre-loading {name}: {model_id}...")
        try:
            # The processor is fetched first; if it fails, the model is skipped.
            proc_cls.from_pretrained(model_id)
            model_cls.from_pretrained(model_id)
        except Exception as e:  # top-level best-effort boundary: report and continue
            print(f"❌ Failed to load {name}: {e}")
            failed.append(name)
        else:
            print(f"✅ {name} loaded.")
    return failed

def _run_as_script():
    """Entry point used only when this file is executed directly: run preload()."""
    preload()


if __name__ == "__main__":
    _run_as_script()