# NOTE(review): removed non-source residue from this span (a file-size banner,
# a row of git-blame commit hashes, and a line-number ruler) — extraction
# artifacts, not part of the module.
import json
import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configure root logging once at import time; module-level logger used
# throughout this handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for the
    FrierenChatbot causal LM.

    Loads the tokenizer and model once at construction time and serves
    requests through the endpoint calling convention::

        handler({"inputs": {"message": "...", "context": "..."}})

    ``__call__`` always returns a one-element list containing either
    ``{"generated_text": ...}`` on success or ``{"error": ...}`` on failure.
    """

    def __init__(self, model_dir: str = None):
        """Initialize the handler and eagerly load model + tokenizer.

        :param model_dir: model repo id or local path; falls back to the
            published checkpoint when the serving framework passes ``None``.
        :raises Exception: re-raised from the underlying load on failure.
        """
        self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"初始化 EndpointHandler,使用設備: {self.device}")
        try:
            logger.info("開始載入 tokenizer 和模型")
            self._load_model()
            logger.info("模型和 tokenizer 載入完成")
        except Exception as e:
            logger.error(f"初始化錯誤: {str(e)}")
            raise

    def _load_model(self) -> None:
        """Load tokenizer and model from ``self.model_dir`` onto ``self.device``.

        Single home for the loading logic previously duplicated in
        ``__init__`` and ``initialize``. Raises on failure so each caller
        can log its own context and decide whether to abort.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir,
            trust_remote_code=True,
        )
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # padding / attention masks work during batched tokenization.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # fp16 on GPU for memory/speed; fp32 on CPU where fp16 inference is
        # slow or unsupported.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir,
            trust_remote_code=True,
            torch_dtype=dtype,
        ).to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Endpoint entry point.

        :param data: raw request payload from the serving framework.
        :returns: one-element list with ``{"generated_text": ...}`` or,
            if anything goes wrong, ``{"error": <message>}`` — errors are
            reported in-band rather than raised, matching the endpoint
            response contract.
        """
        try:
            if self.tokenizer is None or self.model is None:
                raise RuntimeError("Tokenizer or model not initialized")
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return [outputs]
        except Exception as e:
            logger.error(f"處理過程錯誤: {str(e)}")
            return [{"error": str(e)}]

    def initialize(self, context):
        """Serving-framework hook: (re)load the model if it is missing.

        :param context: framework-provided context object (unused here;
            presumably a model-server context — TODO confirm against the
            deployment runtime).
        :raises Exception: re-raised from the underlying load on failure.
        """
        if self.tokenizer is None or self.model is None:
            logger.info("在 initialize 中重新初始化模型")
            try:
                self._load_model()
                logger.info("模型重新初始化完成")
            except Exception as e:
                logger.error(f"模型重新初始化錯誤: {str(e)}")
                raise

    @staticmethod
    def _coerce_to_dict(raw):
        """Interpret ``raw`` as a request dict.

        A ``str`` payload is parsed as JSON when possible and otherwise
        wrapped as ``{"message": raw}``; any non-str value passes through
        unchanged. Single home for the coercion previously duplicated
        twice inside ``inference``.
        """
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return {"message": raw}
        return raw

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Build the persona prompt, generate, and post-process the reply.

        :param inputs: request payload; may be a JSON string, a plain
            message string, or a dict optionally nesting the real payload
            under ``"inputs"`` (the HF endpoint protocol shape).
        :returns: ``{"generated_text": ...}`` or ``{"error": ...}``.
        """
        logger.info("開始執行推理")
        try:
            inputs = self._coerce_to_dict(inputs)
            # The HF endpoint protocol nests the real payload under "inputs";
            # that inner value may itself be a JSON string.
            if isinstance(inputs, dict) and "inputs" in inputs:
                inputs = self._coerce_to_dict(inputs["inputs"])
            message = inputs.get("message", "")
            context = inputs.get("context", "")
            logger.info(f"處理消息: {message}, 上下文: {context}")
            # Persona prompt — kept verbatim; its exact wording is part of
            # the model's expected input distribution.
            prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
1. 身份設定:
- 千年精靈魔法師
- 態度溫柔但帶著些許嘲諷
- 說話優雅且有距離感
2. 重要關係:
- 弗蘭梅是我的師傅
- 費倫是我的學生
- 欣梅爾是我的摯友
- 海塔是我的故友
3. 回答規則:
- 使用繁體中文
- 必須提供具體詳細的內容
- 保持回答的連貫性和完整性
相關資訊:{context}
用戶:{message}
芙莉蓮:"""
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized")
            tokens = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokens,
                    max_new_tokens=150,
                    temperature=0.7,
                    top_p=0.9,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text after the final persona marker (the decoded
            # string echoes the whole prompt).
            response = response.split("芙莉蓮:")[-1].strip()
            if not response:
                # Canned fallback so the endpoint never returns an empty reply.
                response = "唔...讓我思考一下如何回答你的問題。"
            logger.info(f"生成回應: {response}")
            return {"generated_text": response}
        except Exception as e:
            logger.error(f"推理過程錯誤: {str(e)}")
            return {"error": str(e)}

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Request-preprocessing hook; currently a logged passthrough."""
        logger.info(f"預處理輸入數據: {data}")
        return data