banking2b / models.py
hainc
Refactor: Tách code thành cấu trúc module chuyên nghiệp
124d692
"""
Module quản lý việc load và quản lý models (LLM, STT)
"""
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
Qwen3VLForConditionalGeneration,
AutoProcessor
)
from config import MODEL_NAME, USE_FASTER_WHISPER, WHISPER_MODEL_SIZE, MAX_GPU_MEMORY, MAX_CPU_MEMORY
# Global model instances
model = None
tokenizer = None
processor = None
whisper_model = None
is_qwen3vl = False
def load_llm_model():
"""Load LLM model và tokenizer"""
global model, tokenizer, processor, is_qwen3vl
try:
print(f"Đang tải model {MODEL_NAME}...")
# Thử load config trước
try:
config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
model_type = getattr(config, 'model_type', 'Not specified')
print(f"Config loaded. Model type: {model_type}")
except Exception as config_error:
print(f"⚠️ Không thể load config: {config_error}")
config = None
model_type = None
# Kiểm tra GPU
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
# Load model với cấu hình tối đa
load_kwargs = {
"low_cpu_mem_usage": True,
"dtype": torch.bfloat16,
"device_map": "auto",
"trust_remote_code": True,
"attn_implementation": "sdpa",
"max_memory": {0: MAX_GPU_MEMORY, "cpu": MAX_CPU_MEMORY} if torch.cuda.is_available() else None,
}
# Tối ưu CUDA settings
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print("⚡ Đã bật CUDNN benchmark và TF32 để tăng tốc độ")
# Kiểm tra nếu là Qwen3VL model
if model_type == "qwen3_vl":
is_qwen3vl = True
print("⚠️ Phát hiện Qwen3VL model. Đang load với Qwen3VLForConditionalGeneration...")
try:
model = Qwen3VLForConditionalGeneration.from_pretrained(MODEL_NAME, **load_kwargs)
try:
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
print("✅ Load Qwen3VL processor thành công!")
except:
print("⚠️ Không thể load processor, sử dụng tokenizer...")
processor = None
print("✅ Load Qwen3VL model thành công!")
except Exception as vl_error:
print(f"❌ Lỗi khi load Qwen3VL: {vl_error}")
raise vl_error
else:
# Load với AutoModelForCausalLM
try:
if config:
load_kwargs["config"] = config
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
# Tối ưu: Compile model (chỉ trên GPU mạnh)
try:
if hasattr(torch, 'compile') and torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0).lower()
if any(x in gpu_name for x in ['a100', 'h100', 'l4', 'h200', 'v100']):
print("⚡ Đang compile model với torch.compile...")
model = torch.compile(model, mode="reduce-overhead", fullgraph=False, dynamic=False)
print("✅ Model đã được compile thành công!")
else:
print(f"⚠️ GPU {gpu_name}: Bỏ qua torch.compile để tránh overhead")
except Exception as compile_error:
print(f"⚠️ Không thể compile model: {compile_error}")
# Tối ưu: Enable SDPA
try:
if hasattr(model.config, 'attn_implementation'):
model.config.attn_implementation = "sdpa"
print("⚡ Đã bật SDPA (Scaled Dot Product Attention)")
except:
pass
# Tối ưu: Pre-warm model (chỉ trên GPU mạnh)
try:
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0).lower()
if any(x in gpu_name for x in ['a100', 'h100', 'l4', 'h200']):
print("⚡ Đang pre-warm model...")
dummy_input = torch.randint(0, 1000, (1, 10), device=model.device)
with torch.inference_mode():
_ = model.generate(dummy_input, max_new_tokens=1, use_cache=True)
print("✅ Model đã được pre-warm thành công!")
else:
print("⚠️ Bỏ qua pre-warm để tăng tốc startup")
except Exception as warm_error:
print(f"⚠️ Không thể pre-warm model: {warm_error}")
print("✅ Load model với AutoModelForCausalLM thành công!")
except Exception as model_error:
error_msg = str(model_error)
if "Unrecognized" in error_msg or "model_type" in error_msg.lower():
print("⚠️ Model không được nhận diện. Thử load với AutoModel...")
try:
from transformers import AutoModel
model = AutoModel.from_pretrained(MODEL_NAME, **load_kwargs)
print("✅ Load model với AutoModel thành công!")
except Exception as auto_error:
print(f"❌ Không thể load với AutoModel: {auto_error}")
raise model_error
else:
raise model_error
# Load tokenizer
if processor is None:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
else:
tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
print("✅ LLM Model đã được tải thành công!")
return True
except Exception as e:
print(f"❌ Lỗi khi tải LLM model: {e}")
import traceback
traceback.print_exc()
return False
def load_stt_model():
"""Load STT (Speech-to-Text) model"""
global whisper_model
try:
if USE_FASTER_WHISPER:
# faster-whisper: nhanh hơn 4-5x
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"⚡ Đang tải faster-whisper ({WHISPER_MODEL_SIZE}) trên {device}...")
from faster_whisper import WhisperModel
whisper_model = WhisperModel(
WHISPER_MODEL_SIZE,
device=device,
compute_type=compute_type,
device_index=0,
num_workers=4 if device == "cuda" else 1,
)
print(f"✅ faster-whisper ({WHISPER_MODEL_SIZE}) đã được tải thành công!")
else:
# Fallback: openai-whisper
import whisper
print(f"⚡ Đang tải Whisper model ({WHISPER_MODEL_SIZE}) cho STT...")
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=device)
if device == "cuda":
whisper_model = whisper_model.half()
print(f"✅ Whisper model ({WHISPER_MODEL_SIZE}) đã được tải thành công!")
return True
except Exception as e:
print(f"❌ Lỗi khi tải STT model: {e}")
import traceback
traceback.print_exc()
return False
def load_all_models():
"""Load tất cả models"""
print("=" * 50)
print("BẮT ĐẦU TẢI MODELS...")
print("=" * 50)
llm_loaded = load_llm_model()
stt_loaded = load_stt_model()
if llm_loaded and stt_loaded:
print("=" * 50)
print("✅ TẤT CẢ MODELS ĐÃ ĐƯỢC TẢI THÀNH CÔNG!")
print("=" * 50)
return True
else:
print("=" * 50)
print("⚠️ CẢNH BÁO: CÓ LỖI KHI TẢI MODELS!")
print("=" * 50)
return False
def get_model():
"""Lấy LLM model instance"""
return model
def get_tokenizer():
"""Lấy tokenizer instance"""
return tokenizer
def get_processor():
"""Lấy processor instance (nếu có)"""
return processor
def get_whisper_model():
"""Lấy Whisper model instance"""
return whisper_model
def is_qwen3vl_model():
"""Kiểm tra xem có phải Qwen3VL model không"""
return is_qwen3vl