gemma-3-270m / model.py
# -*- coding: utf-8 -*-
import torch
import os
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from typing import Optional
import config
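# `config` is a local settings module (config.py) that lives next to this file. From the
# attributes referenced below, it is expected to provide at least: MODEL_NAME,
# PIPELINE_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K,
# DEFAULT_REPETITION_PENALTY, MIN_TOKENS and MAX_TOKENS. A minimal sketch of such a
# module (values below are illustrative assumptions, not the deployed settings):
#
#     MODEL_NAME = "google/gemma-3-270m"   # assumed model id
#     PIPELINE_MAX_NEW_TOKENS = 256
#     DEFAULT_TEMPERATURE = 0.7
#     DEFAULT_TOP_P = 0.9
#     DEFAULT_TOP_K = 50
#     DEFAULT_REPETITION_PENALTY = 1.1
#     MIN_TOKENS = 16
#     MAX_TOKENS = 1024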
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    """Model manager providing error handling and load-state tracking."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.generator = None
        self.is_loaded = False
    def load_model(self) -> bool:
        """Load the model; return True on success."""
        try:
            # Check for an HF token
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                logger.warning("HF_TOKEN environment variable not found; some models may be inaccessible")
            else:
                logger.info(f"Found HF_TOKEN: {hf_token[:10]}...")

            logger.info(f"Loading model: {config.MODEL_NAME}")

            # Device detection
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Device set to use {device}")

            # Load the model
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                token=hf_token,
                trust_remote_code=True
            )

            # Set pad_token if it is missing
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Create the text-generation pipeline
            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=config.PIPELINE_MAX_NEW_TOKENS,
                temperature=config.DEFAULT_TEMPERATURE,
                do_sample=True,
                top_p=config.DEFAULT_TOP_P,
                pad_token_id=self.tokenizer.eos_token_id,
                return_full_text=False
            )

            self.is_loaded = True
            logger.info("✅ Model loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            self.is_loaded = False
            return False
    def generate_response(self, user_message: str,
                          max_tokens: Optional[int] = None,
                          temperature: Optional[float] = None,
                          top_p: Optional[float] = None,
                          top_k: Optional[int] = None,
                          repetition_penalty: Optional[float] = None) -> str:
        """Generate a reply to the given user message."""
        if not self.is_loaded or not self.generator:
            return "❌ Model not loaded. Please check the configuration or contact the administrator."
        try:
            # Use defaults unless explicit arguments were passed
            generation_tokens = max_tokens if max_tokens is not None else config.PIPELINE_MAX_NEW_TOKENS
            temp = temperature if temperature is not None else config.DEFAULT_TEMPERATURE
            tp = top_p if top_p is not None else config.DEFAULT_TOP_P
            tk = top_k if top_k is not None else config.DEFAULT_TOP_K
            rp = repetition_penalty if repetition_penalty is not None else config.DEFAULT_REPETITION_PENALTY

            # Clamp parameters to valid ranges
            generation_tokens = max(config.MIN_TOKENS, min(config.MAX_TOKENS, generation_tokens))
            temp = max(0.1, min(2.0, temp))
            tp = max(0.1, min(1.0, tp))
            tk = max(1, min(100, tk))
            rp = max(1.0, min(2.0, rp))

            # Build the chat message format
            messages = [{"role": "user", "content": user_message}]

            # Apply the chat template
            if hasattr(self.tokenizer, 'apply_chat_template'):
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Fallback prompt format
                prompt = f"User: {user_message}\nAssistant: "

            # Generate the reply
            outputs = self.generator(
                prompt,
                max_new_tokens=generation_tokens,
                temperature=temp,
                do_sample=True,
                top_p=tp,
                top_k=tk,
                repetition_penalty=rp,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Extract the generated text
            if isinstance(outputs, list) and len(outputs) > 0:
                generated_text = outputs[0].get("generated_text", "")
                # If the full text was returned, strip the prompt prefix
                if generated_text.startswith(prompt):
                    response = generated_text[len(prompt):].strip()
                else:
                    response = generated_text.strip()
            else:
                response = "Sorry, something went wrong while generating the reply."

            # Strip special tokens
            response = self._clean_response(response)
            return response if response else "Sorry, I could not generate a suitable reply."
        except Exception as e:
            logger.error(f"Error while generating reply: {str(e)}")
            return f"❌ Error while generating reply: {str(e)}"
    def _clean_response(self, text: str) -> str:
        """Remove special tokens from the reply text."""
        special_tokens = [
            "<end_of_turn>",
            "<|end|>",
            "<|endoftext|>",
            "</s>",
            "<eos>",
            "<pad>"
        ]
        for token in special_tokens:
            if token in text:
                text = text.split(token)[0].strip()
        return text.strip()
# Initialize and load the model manager
model_manager = ModelManager()
model_loaded = model_manager.load_model()
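
# Minimal usage sketch: this module is normally imported by the app entry point, but it
# can also be exercised directly. The prompt and parameter values below are only
# illustrations, not part of the deployed configuration.
if __name__ == "__main__":
    if model_loaded:
        reply = model_manager.generate_response(
            "Hello! Please introduce yourself briefly.",  # hypothetical test prompt
            max_tokens=128,
            temperature=0.7,
        )
        print(reply)
    else:
        print("Model failed to load; check HF_TOKEN and config.MODEL_NAME.")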