| |
|
|
| import torch |
| import os |
| import logging |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
| from typing import Optional |
|
|
| import config |
|
|
| |
# Module-scoped logger used by everything below; created first so the name is
# bound regardless of how the root logger ends up configured.
logger = logging.getLogger(__name__)

# Install a default root handler at INFO level (no-op when the root logger
# already has handlers configured by the host application).
logging.basicConfig(level=logging.INFO)
|
|
class ModelManager:
    """Load a causal language model once and generate chat replies with it.

    The manager owns the model, its tokenizer and a text-generation pipeline,
    and tracks whether loading succeeded in ``is_loaded``. Failures are
    reported through return values (``False`` / an error string) rather than
    raised, so callers can surface them directly to the user.
    """

    # Markers after which generated text is discarded by ``_clean_response``.
    _STOP_MARKERS = (
        "<end_of_turn>",
        "<|end|>",
        "<|endoftext|>",
        "</s>",
        "<eos>",
        "<pad>",
    )

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.generator = None
        self.is_loaded = False

    def load_model(self) -> bool:
        """Load model, tokenizer and pipeline for ``config.MODEL_NAME``.

        Reads the optional ``HF_TOKEN`` environment variable and passes it to
        the ``from_pretrained`` calls. Uses float16 with ``device_map="auto"``
        when CUDA is available, float32 with no device map otherwise.

        Returns:
            True when everything loaded and ``is_loaded`` was set; False when
            any step raised (the error is logged with its traceback).
        """
        try:
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                logger.warning("未找到 HF_TOKEN 环境变量,某些模型可能无法访问")
            else:
                # Never write any part of the secret to the log.
                logger.info("已找到 HF_TOKEN 环境变量")

            logger.info(f"正在加载模型: {config.MODEL_NAME}")

            use_cuda = torch.cuda.is_available()
            device = "cuda" if use_cuda else "cpu"
            logger.info(f"Device set to use {device}")

            # NOTE(security): trust_remote_code=True lets the model repository
            # run its own Python code during loading; MODEL_NAME must point at
            # a repository you trust.
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
            )

            # Some tokenizers ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=config.PIPELINE_MAX_NEW_TOKENS,
                temperature=config.DEFAULT_TEMPERATURE,
                do_sample=True,
                top_p=config.DEFAULT_TOP_P,
                pad_token_id=self.tokenizer.eos_token_id,
                return_full_text=False,
            )

            self.is_loaded = True
            logger.info("✅ 模型加载成功!")
            return True

        except Exception as e:
            # Top-level boundary: keep the traceback, and drop any partially
            # built objects so a failed load does not keep them referenced.
            logger.exception(f"❌ 模型加载失败: {e}")
            self.model = None
            self.tokenizer = None
            self.generator = None
            self.is_loaded = False
            return False

    def generate_response(self, user_message: str,
                          max_tokens: Optional[int] = None,
                          temperature: Optional[float] = None,
                          top_p: Optional[float] = None,
                          top_k: Optional[int] = None,
                          repetition_penalty: Optional[float] = None) -> str:
        """Generate a reply to a single user message.

        Each sampling argument left as ``None`` falls back to the matching
        default in ``config``; every value is then clamped to a fixed range
        before being handed to the pipeline.

        Args:
            user_message: The user's text.
            max_tokens: New-token budget, clamped to
                ``[config.MIN_TOKENS, config.MAX_TOKENS]``.
            temperature: Sampling temperature, clamped to ``[0.1, 2.0]``.
            top_p: Nucleus sampling mass, clamped to ``[0.1, 1.0]``.
            top_k: Top-k cutoff, clamped to ``[1, 100]``.
            repetition_penalty: Clamped to ``[1.0, 2.0]``.

        Returns:
            The cleaned reply, or a user-facing message when the model is not
            loaded, nothing usable was generated, or generation raised.
        """
        if not self.is_loaded or not self.generator:
            return "❌ 模型未加载,请检查配置或联系管理员。"

        try:
            generation_tokens = max_tokens if max_tokens is not None else config.PIPELINE_MAX_NEW_TOKENS
            temp = temperature if temperature is not None else config.DEFAULT_TEMPERATURE
            tp = top_p if top_p is not None else config.DEFAULT_TOP_P
            tk = top_k if top_k is not None else config.DEFAULT_TOP_K
            rp = repetition_penalty if repetition_penalty is not None else config.DEFAULT_REPETITION_PENALTY

            generation_tokens = self._clamp(generation_tokens, config.MIN_TOKENS, config.MAX_TOKENS)
            temp = self._clamp(temp, 0.1, 2.0)
            tp = self._clamp(tp, 0.1, 1.0)
            tk = self._clamp(tk, 1, 100)
            rp = self._clamp(rp, 1.0, 2.0)

            prompt = self._build_prompt(user_message)

            outputs = self.generator(
                prompt,
                max_new_tokens=generation_tokens,
                temperature=temp,
                do_sample=True,
                top_p=tp,
                top_k=tk,
                repetition_penalty=rp,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            if isinstance(outputs, list) and len(outputs) > 0:
                generated_text = outputs[0].get("generated_text", "")
                # The pipeline is built with return_full_text=False, but strip
                # an echoed prompt anyway in case the output still carries it.
                if generated_text.startswith(prompt):
                    response = generated_text[len(prompt):].strip()
                else:
                    response = generated_text.strip()
            else:
                response = "抱歉,生成回复时出现问题。"

            response = self._clean_response(response)
            return response if response else "抱歉,我无法生成合适的回复。"

        except Exception as e:
            # Top-level boundary for the caller: log with traceback and return
            # the error as text instead of raising.
            logger.exception(f"生成回复时出错: {str(e)}")
            return f"❌ 生成回复时出错: {str(e)}"

    @staticmethod
    def _clamp(value, low, high):
        """Return *value* limited to the inclusive range [low, high]."""
        return max(low, min(high, value))

    def _build_prompt(self, user_message: str) -> str:
        """Render *user_message* as a single-turn prompt string.

        Uses the tokenizer's chat template when one is configured; otherwise
        falls back to a plain ``User:/Assistant:`` layout.
        """
        # hasattr(tokenizer, "apply_chat_template") is true for every
        # tokenizer, so it cannot select the fallback. What matters is whether
        # a template is actually set: on recent transformers releases calling
        # apply_chat_template without one raises.
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": user_message}]
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        return f"User: {user_message}\nAssistant: "

    def _clean_response(self, text: str) -> str:
        """Cut *text* at the first special-token marker and trim whitespace."""
        for marker in self._STOP_MARKERS:
            if marker in text:
                text = text.split(marker)[0].strip()
        return text.strip()
|
|
| |
# Shared module-level manager instance. The model is loaded eagerly at import
# time; `model_loaded` records whether that attempt succeeded, because
# load_model() reports failure through its return value instead of raising.
model_manager = ModelManager()
model_loaded: bool = model_manager.load_model()