File size: 2,697 Bytes
fad5c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""

Model Manager - Handles loading and caching of AI models

Manages Whisper and NLLB-200 models with GPU optimization

"""

import torch
import logging
from typing import Optional
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

logger = logging.getLogger(__name__)


class ModelManager:
    """Singleton that loads and caches Whisper and NLLB-200 model instances.

    Every ``ModelManager()`` call returns the same instance, so each model is
    loaded at most once per requested size/name and reused thereafter.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: allocate the sole instance on first call only.
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on *every* ModelManager() call; guard so that the
        # shared state is only initialized once.
        if self._initialized:
            return

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        self.whisper_model = None
        # Size the cached Whisper model was loaded with (None = not loaded).
        self.whisper_model_size = None
        self.nllb_tokenizer = None
        self.nllb_model = None
        # Name the cached NLLB model/tokenizer were loaded with.
        self.nllb_model_name = None

        self._initialized = True

    def get_whisper_model(self, model_size: str = "large") -> whisper.Whisper:
        """Return the cached Whisper model, (re)loading it when needed.

        Fix: the previous version ignored ``model_size`` once any model was
        cached, silently returning a model of the wrong size. We now track
        the loaded size and reload when a different one is requested.

        Args:
            model_size: Whisper checkpoint size (e.g. "base", "large").

        Returns:
            The loaded ``whisper.Whisper`` model on ``self.device``.
        """
        if self.whisper_model is None or self.whisper_model_size != model_size:
            logger.info(f"Loading Whisper {model_size} model...")
            self.whisper_model = whisper.load_model(model_size, device=self.device)
            self.whisper_model_size = model_size
        return self.whisper_model

    def get_nllb_model(self, model_name: str = "facebook/nllb-200-distilled-600M"):
        """Return ``(model, tokenizer)`` for NLLB-200, (re)loading when needed.

        Fix: previously a cached model was returned even when a different
        ``model_name`` was requested; the loaded name is now tracked.

        Args:
            model_name: Hugging Face model id to load.

        Returns:
            Tuple of (AutoModelForSeq2SeqLM, AutoTokenizer).
        """
        if self.nllb_model is None or self.nllb_model_name != model_name:
            logger.info(f"Loading NLLB-200 model: {model_name}")
            # NOTE(review): trust_remote_code=True executes code from the model
            # repo — confirm the repo is trusted. ``use_auth_token`` is
            # deprecated in newer transformers releases in favor of ``token``;
            # kept here for compatibility with the pinned version.
            self.nllb_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                use_auth_token=True,
                trust_remote_code=True
            )
            self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                # fp16 halves GPU memory; fall back to fp32 on CPU, where
                # half precision is slow/unsupported for many ops.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            if self.device == "cpu":
                # device_map is None on CPU, so place the model explicitly.
                self.nllb_model = self.nllb_model.to(self.device)
            self.nllb_model_name = model_name

        return self.nllb_model, self.nllb_tokenizer

    def get_device(self) -> str:
        """Get current device ("cuda" or "cpu")."""
        return self.device

    def unload_all(self):
        """Unload all models and release cached GPU memory."""
        logger.info("Unloading all models...")
        self.whisper_model = None
        self.whisper_model_size = None
        self.nllb_model = None
        self.nllb_tokenizer = None
        self.nllb_model_name = None
        if torch.cuda.is_available():
            # Return freed blocks from PyTorch's caching allocator to the GPU.
            torch.cuda.empty_cache()