# -*- coding: utf-8 -*-
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

import config

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Manages the model, with robust error handling and state tracking."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.generator = None
        self.is_loaded = False

    def load_model(self) -> bool:
        """Load the model; returns True on success."""
        try:
            # Check for a Hugging Face token
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                logger.warning("HF_TOKEN environment variable not set; gated models may be inaccessible")
            else:
                logger.info(f"Found HF_TOKEN: {hf_token[:10]}...")
            logger.info(f"Loading model: {config.MODEL_NAME}")
            # Device detection
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Device set to use {device}")
            # Load the model
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                # `torch_dtype` is the long-supported kwarg name; very recent
                # transformers releases also accept it as `dtype`
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                trust_remote_code=True
            )
            # Set a pad token if the tokenizer does not define one
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Build the text-generation pipeline
            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=config.PIPELINE_MAX_NEW_TOKENS,
                temperature=config.DEFAULT_TEMPERATURE,
                do_sample=True,
                top_p=config.DEFAULT_TOP_P,
                pad_token_id=self.tokenizer.eos_token_id,
                return_full_text=False
            )
            self.is_loaded = True
            logger.info("✅ Model loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            self.is_loaded = False
            return False

    def generate_response(self, user_message: str,
                          max_tokens: Optional[int] = None,
                          temperature: Optional[float] = None,
                          top_p: Optional[float] = None,
                          top_k: Optional[int] = None,
                          repetition_penalty: Optional[float] = None) -> str:
        """Generate a reply to the user message."""
        if not self.is_loaded or not self.generator:
            return "❌ Model not loaded; check the configuration or contact the administrator."
        try:
            # Fall back to the configured defaults for any unset parameter
            generation_tokens = max_tokens if max_tokens is not None else config.PIPELINE_MAX_NEW_TOKENS
            temp = temperature if temperature is not None else config.DEFAULT_TEMPERATURE
            tp = top_p if top_p is not None else config.DEFAULT_TOP_P
            tk = top_k if top_k is not None else config.DEFAULT_TOP_K
            rp = repetition_penalty if repetition_penalty is not None else config.DEFAULT_REPETITION_PENALTY
            # Clamp every parameter to a safe range
            generation_tokens = max(config.MIN_TOKENS, min(config.MAX_TOKENS, generation_tokens))
            temp = max(0.1, min(2.0, temp))
            tp = max(0.1, min(1.0, tp))
            tk = max(1, min(100, tk))
            rp = max(1.0, min(2.0, rp))
            # Build the chat-style message list
            messages = [{"role": "user", "content": user_message}]
            # Apply the tokenizer's chat template when it has one
            if hasattr(self.tokenizer, 'apply_chat_template'):
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Fallback prompt format
                prompt = f"User: {user_message}\nAssistant: "
            # Generate the reply
            outputs = self.generator(
                prompt,
                max_new_tokens=generation_tokens,
                temperature=temp,
                do_sample=True,
                top_p=tp,
                top_k=tk,
                repetition_penalty=rp,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Extract the generated text
            if isinstance(outputs, list) and len(outputs) > 0:
                generated_text = outputs[0].get("generated_text", "")
                # The pipeline was built with return_full_text=False, but strip
                # the prompt defensively in case the full text comes back
                if generated_text.startswith(prompt):
                    response = generated_text[len(prompt):].strip()
                else:
                    response = generated_text.strip()
            else:
                response = "Sorry, something went wrong while generating a reply."
            # Strip trailing special tokens
            response = self._clean_response(response)
            return response if response else "Sorry, I could not generate a suitable reply."
        except Exception as e:
            logger.error(f"Error while generating a reply: {str(e)}")
            return f"❌ Error while generating a reply: {str(e)}"

    def _clean_response(self, text: str) -> str:
        """Strip special tokens from the reply text."""
        special_tokens = [
            "<end_of_turn>",
            "<|end|>",
            "<|endoftext|>",
            "</s>",
            "<eos>",
            "<pad>"
        ]
        for token in special_tokens:
            if token in text:
                text = text.split(token)[0].strip()
        return text.strip()


# Initialize and load the model manager
model_manager = ModelManager()
model_loaded = model_manager.load_model()
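
The script imports a `config` module that is not shown on this page. Below is a minimal sketch of what it would need to expose, going only by the attributes referenced above: the attribute names come from the code, while every value (including the model name) is an illustrative assumption. The `<end_of_turn>` token handled in `_clean_response` hints at a Gemma-family checkpoint, but that is an inference, not something the page states.

# config.py -- minimal sketch; attribute names are taken from the code above,
# all values are illustrative assumptions
MODEL_NAME = "google/gemma-2-2b-it"  # assumed checkpoint; any causal LM works

PIPELINE_MAX_NEW_TOKENS = 512        # default max_new_tokens for the pipeline
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 50
DEFAULT_REPETITION_PENALTY = 1.1

MIN_TOKENS = 1                       # clamp bounds for max_new_tokens
MAX_TOKENS = 2048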
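
For completeness, here is a hedged sketch of how the manager could be driven from the bottom of this same file. Only `model_manager.generate_response` and its parameters come from the code above; the Gradio wiring is an assumption (Spaces apps commonly use `gradio`), not the original app's UI code.

# Usage sketch -- assumed to sit at the bottom of this same file.
# The Gradio wiring is illustrative; the original UI code is not shown here.
import gradio as gr

def chat_fn(message, history):
    # generate_response clamps its sampling parameters internally,
    # so values passed here are safe even if slightly out of range
    return model_manager.generate_response(message, max_tokens=256, temperature=0.7)

if model_loaded:
    gr.ChatInterface(chat_fn, title="Chat demo").launch()
else:
    logger.error("Model failed to load; see the log output above.")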