Spaces:
Build error
Build error
| """ | |
| Module quản lý việc load và quản lý models (LLM, STT) | |
| """ | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| AutoConfig, | |
| Qwen3VLForConditionalGeneration, | |
| AutoProcessor | |
| ) | |
| from config import MODEL_NAME, USE_FASTER_WHISPER, WHISPER_MODEL_SIZE, MAX_GPU_MEMORY, MAX_CPU_MEMORY | |
| # Global model instances | |
| model = None | |
| tokenizer = None | |
| processor = None | |
| whisper_model = None | |
| is_qwen3vl = False | |
| def load_llm_model(): | |
| """Load LLM model và tokenizer""" | |
| global model, tokenizer, processor, is_qwen3vl | |
| try: | |
| print(f"Đang tải model {MODEL_NAME}...") | |
| # Thử load config trước | |
| try: | |
| config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
| model_type = getattr(config, 'model_type', 'Not specified') | |
| print(f"Config loaded. Model type: {model_type}") | |
| except Exception as config_error: | |
| print(f"⚠️ Không thể load config: {config_error}") | |
| config = None | |
| model_type = None | |
| # Kiểm tra GPU | |
| if torch.cuda.is_available(): | |
| print(f"GPU: {torch.cuda.get_device_name(0)}") | |
| print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB") | |
| # Load model với cấu hình tối đa | |
| load_kwargs = { | |
| "low_cpu_mem_usage": True, | |
| "dtype": torch.bfloat16, | |
| "device_map": "auto", | |
| "trust_remote_code": True, | |
| "attn_implementation": "sdpa", | |
| "max_memory": {0: MAX_GPU_MEMORY, "cpu": MAX_CPU_MEMORY} if torch.cuda.is_available() else None, | |
| } | |
| # Tối ưu CUDA settings | |
| if torch.cuda.is_available(): | |
| torch.backends.cudnn.benchmark = True | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| print("⚡ Đã bật CUDNN benchmark và TF32 để tăng tốc độ") | |
| # Kiểm tra nếu là Qwen3VL model | |
| if model_type == "qwen3_vl": | |
| is_qwen3vl = True | |
| print("⚠️ Phát hiện Qwen3VL model. Đang load với Qwen3VLForConditionalGeneration...") | |
| try: | |
| model = Qwen3VLForConditionalGeneration.from_pretrained(MODEL_NAME, **load_kwargs) | |
| try: | |
| processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
| print("✅ Load Qwen3VL processor thành công!") | |
| except: | |
| print("⚠️ Không thể load processor, sử dụng tokenizer...") | |
| processor = None | |
| print("✅ Load Qwen3VL model thành công!") | |
| except Exception as vl_error: | |
| print(f"❌ Lỗi khi load Qwen3VL: {vl_error}") | |
| raise vl_error | |
| else: | |
| # Load với AutoModelForCausalLM | |
| try: | |
| if config: | |
| load_kwargs["config"] = config | |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs) | |
| # Tối ưu: Compile model (chỉ trên GPU mạnh) | |
| try: | |
| if hasattr(torch, 'compile') and torch.cuda.is_available(): | |
| gpu_name = torch.cuda.get_device_name(0).lower() | |
| if any(x in gpu_name for x in ['a100', 'h100', 'l4', 'h200', 'v100']): | |
| print("⚡ Đang compile model với torch.compile...") | |
| model = torch.compile(model, mode="reduce-overhead", fullgraph=False, dynamic=False) | |
| print("✅ Model đã được compile thành công!") | |
| else: | |
| print(f"⚠️ GPU {gpu_name}: Bỏ qua torch.compile để tránh overhead") | |
| except Exception as compile_error: | |
| print(f"⚠️ Không thể compile model: {compile_error}") | |
| # Tối ưu: Enable SDPA | |
| try: | |
| if hasattr(model.config, 'attn_implementation'): | |
| model.config.attn_implementation = "sdpa" | |
| print("⚡ Đã bật SDPA (Scaled Dot Product Attention)") | |
| except: | |
| pass | |
| # Tối ưu: Pre-warm model (chỉ trên GPU mạnh) | |
| try: | |
| if torch.cuda.is_available(): | |
| gpu_name = torch.cuda.get_device_name(0).lower() | |
| if any(x in gpu_name for x in ['a100', 'h100', 'l4', 'h200']): | |
| print("⚡ Đang pre-warm model...") | |
| dummy_input = torch.randint(0, 1000, (1, 10), device=model.device) | |
| with torch.inference_mode(): | |
| _ = model.generate(dummy_input, max_new_tokens=1, use_cache=True) | |
| print("✅ Model đã được pre-warm thành công!") | |
| else: | |
| print("⚠️ Bỏ qua pre-warm để tăng tốc startup") | |
| except Exception as warm_error: | |
| print(f"⚠️ Không thể pre-warm model: {warm_error}") | |
| print("✅ Load model với AutoModelForCausalLM thành công!") | |
| except Exception as model_error: | |
| error_msg = str(model_error) | |
| if "Unrecognized" in error_msg or "model_type" in error_msg.lower(): | |
| print("⚠️ Model không được nhận diện. Thử load với AutoModel...") | |
| try: | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained(MODEL_NAME, **load_kwargs) | |
| print("✅ Load model với AutoModel thành công!") | |
| except Exception as auto_error: | |
| print(f"❌ Không thể load với AutoModel: {auto_error}") | |
| raise model_error | |
| else: | |
| raise model_error | |
| # Load tokenizer | |
| if processor is None: | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
| else: | |
| tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
| print("✅ LLM Model đã được tải thành công!") | |
| return True | |
| except Exception as e: | |
| print(f"❌ Lỗi khi tải LLM model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return False | |
| def load_stt_model(): | |
| """Load STT (Speech-to-Text) model""" | |
| global whisper_model | |
| try: | |
| if USE_FASTER_WHISPER: | |
| # faster-whisper: nhanh hơn 4-5x | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| compute_type = "float16" if device == "cuda" else "int8" | |
| print(f"⚡ Đang tải faster-whisper ({WHISPER_MODEL_SIZE}) trên {device}...") | |
| from faster_whisper import WhisperModel | |
| whisper_model = WhisperModel( | |
| WHISPER_MODEL_SIZE, | |
| device=device, | |
| compute_type=compute_type, | |
| device_index=0, | |
| num_workers=4 if device == "cuda" else 1, | |
| ) | |
| print(f"✅ faster-whisper ({WHISPER_MODEL_SIZE}) đã được tải thành công!") | |
| else: | |
| # Fallback: openai-whisper | |
| import whisper | |
| print(f"⚡ Đang tải Whisper model ({WHISPER_MODEL_SIZE}) cho STT...") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=device) | |
| if device == "cuda": | |
| whisper_model = whisper_model.half() | |
| print(f"✅ Whisper model ({WHISPER_MODEL_SIZE}) đã được tải thành công!") | |
| return True | |
| except Exception as e: | |
| print(f"❌ Lỗi khi tải STT model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return False | |
| def load_all_models(): | |
| """Load tất cả models""" | |
| print("=" * 50) | |
| print("BẮT ĐẦU TẢI MODELS...") | |
| print("=" * 50) | |
| llm_loaded = load_llm_model() | |
| stt_loaded = load_stt_model() | |
| if llm_loaded and stt_loaded: | |
| print("=" * 50) | |
| print("✅ TẤT CẢ MODELS ĐÃ ĐƯỢC TẢI THÀNH CÔNG!") | |
| print("=" * 50) | |
| return True | |
| else: | |
| print("=" * 50) | |
| print("⚠️ CẢNH BÁO: CÓ LỖI KHI TẢI MODELS!") | |
| print("=" * 50) | |
| return False | |
| def get_model(): | |
| """Lấy LLM model instance""" | |
| return model | |
| def get_tokenizer(): | |
| """Lấy tokenizer instance""" | |
| return tokenizer | |
| def get_processor(): | |
| """Lấy processor instance (nếu có)""" | |
| return processor | |
| def get_whisper_model(): | |
| """Lấy Whisper model instance""" | |
| return whisper_model | |
| def is_qwen3vl_model(): | |
| """Kiểm tra xem có phải Qwen3VL model không""" | |
| return is_qwen3vl | |