File size: 2,741 Bytes
e24fc1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
import json
import logging
from typing import Optional, Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelManager:
    """Hold one causal language model together with its tokenizer.

    The manager tries to load a model as soon as it is constructed. Use
    ``is_model_loaded()`` to find out whether that attempt succeeded.
    """

    # Candidates tried in order when the caller does not name a model.
    DEFAULT_MODELS = (
        "mistralai/Mistral-7B-Instruct-v0.2",  # Best balance of capability/size
        "microsoft/phi-2",  # Fallback
        "gpt2",  # Last resort
    )

    def __init__(self):
        """Initialize the model manager and attempt an initial model load."""
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        # The result is intentionally not checked here: construction succeeds
        # even when no model could be loaded, and callers query
        # is_model_loaded() afterwards.
        self.load_model()

    def load_model(self, model_name: Optional[str] = None) -> bool:
        """Load a language model, trying candidates in order of preference.

        Args:
            model_name: Optional specific model to load. When omitted (or
                empty), every entry of ``DEFAULT_MODELS`` is tried in order.

        Returns:
            bool: True as soon as one candidate loads successfully, False if
            every candidate failed. On failure the previously stored model,
            tokenizer and model name are left untouched.
        """
        candidates = [model_name] if model_name else list(self.DEFAULT_MODELS)

        for model_id in candidates:
            try:
                logger.info("Loading %s", model_id)
                # Load into locals first so that a failure half-way through
                # (tokenizer loaded, model not) cannot leave the manager
                # holding a tokenizer that does not match current_model.
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                # NOTE(review): load_in_8bit presumably needs the optional
                # bitsandbytes backend and a supported GPU -- confirm for the
                # deployment target; without it every candidate may fail.
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",
                    torch_dtype=torch.float16,  # Use half precision
                    load_in_8bit=True,
                )
                model.eval()
            except Exception as e:
                # Best effort: any failure for this candidate just moves on
                # to the next one.
                logger.warning("Failed to load %s: %s", model_id, e)
                continue

            # Commit all three attributes together, only after a full success.
            self.current_tokenizer = tokenizer
            self.current_model = model
            self.current_model_name = model_id
            logger.info("Successfully loaded %s", model_id)
            return True

        return False

    def get_current_model(self) -> tuple:
        """Get currently loaded model and tokenizer as a ``(model, tokenizer)`` pair."""
        return self.current_model, self.current_tokenizer

    def is_model_loaded(self) -> bool:
        """Check if both a model and a tokenizer are currently loaded."""
        return self.current_model is not None and self.current_tokenizer is not None