homer7676 commited on
Commit
bd9671f
·
verified ·
1 Parent(s): a5262a6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -6
handler.py CHANGED
@@ -7,18 +7,18 @@ logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
8
 
9
  class EndpointHandler:
10
- def __init__(self):
 
 
11
  self.tokenizer = None
12
  self.model = None
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  logger.info(f"使用設備: {self.device}")
15
 
16
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]: # 修改返回類型
17
- """處理輸入並返回符合要求的格式"""
18
  try:
19
  inputs = self.preprocess(data)
20
  outputs = self.inference(inputs)
21
- # 確保返回值是列表格式
22
  return [outputs]
23
  except Exception as e:
24
  logger.error(f"處理過程錯誤: {str(e)}")
@@ -29,12 +29,12 @@ class EndpointHandler:
29
  logger.info("開始初始化模型")
30
  try:
31
  self.tokenizer = AutoTokenizer.from_pretrained(
32
- "homer7676/FrierenChatbotV1",
33
  trust_remote_code=True
34
  )
35
 
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
- "homer7676/FrierenChatbotV1",
38
  trust_remote_code=True,
39
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
40
  ).to(self.device)
 
7
  logger = logging.getLogger(__name__)
8
 
9
  class EndpointHandler:
10
+ def __init__(self, model_dir: str = None): # 修改這裡,添加類型提示和默認值
11
+ logger.info(f"初始化 EndpointHandler,model_dir: {model_dir}")
12
+ self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
13
  self.tokenizer = None
14
  self.model = None
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
  logger.info(f"使用設備: {self.device}")
17
 
18
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
 
19
  try:
20
  inputs = self.preprocess(data)
21
  outputs = self.inference(inputs)
 
22
  return [outputs]
23
  except Exception as e:
24
  logger.error(f"處理過程錯誤: {str(e)}")
 
29
  logger.info("開始初始化模型")
30
  try:
31
  self.tokenizer = AutoTokenizer.from_pretrained(
32
+ self.model_dir, # 使用 model_dir
33
  trust_remote_code=True
34
  )
35
 
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
+ self.model_dir, # 使用 model_dir
38
  trust_remote_code=True,
39
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
40
  ).to(self.device)