# -*- coding: utf-8 -*-
"""Load a Hugging Face causal language model and serve single-turn replies.

Importing this module creates ``model_manager`` and immediately tries to load
the model named by ``config.MODEL_NAME``; ``model_loaded`` records whether that
attempt succeeded.
"""
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

import config

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Own the model, tokenizer and generation pipeline, with error handling.

    All failures are reported through return values (``False`` from
    ``load_model``, an error string from ``generate_response``) rather than
    raised, so callers never need their own ``try``/``except``.
    """

    # Markers that end a reply; everything from the first occurrence on is cut.
    # NOTE(review): the previous list also held several empty-string entries,
    # which made ``str.split`` raise ``ValueError`` on every call. They were
    # presumably angle-bracket tokens lost in a copy/paste - restore the
    # intended tokens here if the model needs them.
    _SPECIAL_TOKENS = (
        "<|end|>",
        "<|endoftext|>",
    )

    def __init__(self):
        self._reset_state()

    def _reset_state(self) -> None:
        """Put the manager into the "nothing loaded" state."""
        self.model = None
        self.tokenizer = None
        self.generator = None
        self.is_loaded = False

    def load_model(self) -> bool:
        """Load the model, tokenizer and text-generation pipeline.

        Reads the optional ``HF_TOKEN`` environment variable and the
        ``config.MODEL_NAME``, ``config.PIPELINE_MAX_NEW_TOKENS``,
        ``config.DEFAULT_TEMPERATURE`` and ``config.DEFAULT_TOP_P`` settings.

        Returns:
            ``True`` when everything loaded; ``False`` on any error, in which
            case the error is logged and the manager is left unloaded.
        """
        try:
            # Check for the Hugging Face access token.
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                logger.warning("未找到 HF_TOKEN 环境变量,某些模型可能无法访问")
            else:
                # Never write any part of the secret to the log.
                logger.info("找到 HF_TOKEN")

            logger.info(f"正在加载模型: {config.MODEL_NAME}")

            # Device detection: half precision and automatic placement on GPU,
            # full precision on CPU.
            use_cuda = torch.cuda.is_available()
            device = "cuda" if use_cuda else "cpu"
            logger.info(f"Device set to use {device}")

            # NOTE(security): trust_remote_code=True lets the model repository
            # run its own Python code - only point MODEL_NAME at trusted repos.
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
            )

            # Reuse EOS as the padding token when the tokenizer defines none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Build the generation pipeline with the configured defaults.
            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=config.PIPELINE_MAX_NEW_TOKENS,
                temperature=config.DEFAULT_TEMPERATURE,
                do_sample=True,
                top_p=config.DEFAULT_TOP_P,
                pad_token_id=self.tokenizer.eos_token_id,
                return_full_text=False,
            )

            self.is_loaded = True
            logger.info("✅ 模型加载成功!")
            return True

        except Exception as e:
            logger.exception(f"❌ 模型加载失败: {e}")
            # Drop anything that was partially loaded so memory is released
            # and the attributes agree with ``is_loaded``.
            self._reset_state()
            return False

    def generate_response(
        self,
        user_message: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
    ) -> str:
        """Generate a reply to a single user message.

        Args:
            user_message: The user's text.
            max_tokens: New-token budget; defaults to
                ``config.PIPELINE_MAX_NEW_TOKENS`` and is clamped to
                ``[config.MIN_TOKENS, config.MAX_TOKENS]``.
            temperature: Sampling temperature; defaults to
                ``config.DEFAULT_TEMPERATURE``, clamped to ``[0.1, 2.0]``.
            top_p: Nucleus sampling mass; defaults to
                ``config.DEFAULT_TOP_P``, clamped to ``[0.1, 1.0]``.
            top_k: Top-k cutoff; defaults to ``config.DEFAULT_TOP_K``,
                clamped to ``[1, 100]``.
            repetition_penalty: Defaults to
                ``config.DEFAULT_REPETITION_PENALTY``, clamped to
                ``[1.0, 2.0]``.

        Returns:
            The cleaned reply, or a human-readable error/apology string when
            the model is not loaded, generation fails, or the reply is empty.
        """
        if not self.is_loaded or not self.generator:
            return "❌ 模型未加载,请检查配置或联系管理员。"

        try:
            # Fall back to the configured default for every omitted argument.
            generation_tokens = (
                max_tokens if max_tokens is not None
                else config.PIPELINE_MAX_NEW_TOKENS
            )
            temp = (
                temperature if temperature is not None
                else config.DEFAULT_TEMPERATURE
            )
            tp = top_p if top_p is not None else config.DEFAULT_TOP_P
            tk = top_k if top_k is not None else config.DEFAULT_TOP_K
            rp = (
                repetition_penalty if repetition_penalty is not None
                else config.DEFAULT_REPETITION_PENALTY
            )

            # Clamp every value into its supported range.
            generation_tokens = max(
                config.MIN_TOKENS, min(config.MAX_TOKENS, generation_tokens)
            )
            temp = max(0.1, min(2.0, temp))
            tp = max(0.1, min(1.0, tp))
            tk = max(1, min(100, tk))
            rp = max(1.0, min(2.0, rp))

            prompt = self._build_prompt(user_message)

            outputs = self.generator(
                prompt,
                max_new_tokens=generation_tokens,
                temperature=temp,
                do_sample=True,
                top_p=tp,
                top_k=tk,
                repetition_penalty=rp,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            if isinstance(outputs, list) and len(outputs) > 0:
                generated_text = outputs[0].get("generated_text", "")
                # Defensive: strip the prompt if the pipeline echoed it back.
                if generated_text.startswith(prompt):
                    response = generated_text[len(prompt):].strip()
                else:
                    response = generated_text.strip()
            else:
                response = "抱歉,生成回复时出现问题。"

            response = self._clean_response(response)

            return response if response else "抱歉,我无法生成合适的回复。"

        except Exception as e:
            logger.exception(f"生成回复时出错: {str(e)}")
            return f"❌ 生成回复时出错: {str(e)}"

    def _build_prompt(self, user_message: str) -> str:
        """Render the prompt, preferring the tokenizer's chat template.

        Falls back to a plain ``User:/Assistant:`` layout when the tokenizer
        has no ``apply_chat_template`` method or rejects the call with
        ``ValueError`` (raised when no chat template is configured).
        """
        apply_template = getattr(self.tokenizer, "apply_chat_template", None)
        if callable(apply_template):
            messages = [{"role": "user", "content": user_message}]
            try:
                return apply_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except ValueError as exc:
                logger.warning("聊天模板不可用,使用备用提示格式: %s", exc)

        return f"User: {user_message}\nAssistant: "

    def _clean_response(self, text: str) -> str:
        """Cut the reply at the first special token and trim whitespace."""
        for token in self._SPECIAL_TOKENS:
            # partition() returns the whole text when the token is absent and,
            # unlike split(), has no empty-separator failure mode to guard.
            text = text.partition(token)[0].strip()
        return text.strip()


# Initialize and load the model manager
model_manager = ModelManager()
model_loaded = model_manager.load_model()