alwaysgood commited on
Commit
1e36ed8
·
verified ·
1 Parent(s): 3d6797f

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +78 -0
handler.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import joblib
5
+
6
+ # Hugging Face가 제공하는 기본 핸들러 클래스를 가져옵니다.
7
+ from text_generation_server.models.custom_handler import BaseHandler
8
+
9
+ class TidePredictionHandler(BaseHandler):
10
+ """
11
+ TimeXer 모델을 위한 커스텀 핸들러
12
+ """
13
+ def __init__(self, model, tokenizer):
14
+ # 이 함수는 서버가 시작될 때 딱 한 번 실행됩니다.
15
+ # 모델 초기화, 스케일러 로딩 등 준비 작업을 여기서 합니다.
16
+ super().__init__(model, tokenizer)
17
+
18
+ # 1. 스케일러를 불러와서 self.scaler에 저장합니다.
19
+ # 경로는 저장소 내의 실제 파일 위치와 같아야 합니다.
20
+ scaler_path = os.path.join(os.getcwd(), 'checkpoints', 'scaler.gz')
21
+ self.scaler = joblib.load(scaler_path)
22
+
23
+ # 2. 모델을 평가 모드로 설정합니다.
24
+ self.model.eval()
25
+
26
+ # 3. 모델의 설정값들을 self.model.args 처럼 접근할 수 있도록 저장해두면 편리합니다.
27
+ # 이 부분은 TimeXer 모델의 구조에 따라 필요 없을 수도 있습니다.
28
+ # 예: self.seq_len = self.model.seq_len
29
+
30
+ def __call__(self, inputs, **kwargs):
31
+ # 이 함수는 API 예측 요청이 올 때마다 실행됩니다.
32
+ # 실제 예측 로직이 들어가는 부분입니다.
33
+
34
+ # 1. 입력 데이터 파싱
35
+ # inputs는 보통 리스트 형태의 텍스트 또는 바이트 데이터로 들어옵니다.
36
+ # JSON 형식으로 입력을 받으려면 추가적인 처리가 필요할 수 있습니다.
37
+ # 여기서는 간단히 inputs가 숫자 리스트 문자열이라고 가정합니다.
38
+ # 예: "500.1, 502.3, ..., 498.7"
39
+
40
+ # 문자열을 숫자 리스트로 변환
41
+ try:
42
+ # 입력 데이터를 파싱하는 가장 좋은 방법은 JSON을 사용하는 것입니다.
43
+ # 예: `json.loads(inputs[0])`
44
+ # 여기서는 간단한 예시를 위해 split을 사용합니다.
45
+ input_list = [float(i) for i in inputs[0].split(',')]
46
+ seq_len = 144 # 이 값은 실제 모델의 입력 길이와 일치해야 합니다.
47
+
48
+ if len(input_list) != seq_len:
49
+ raise ValueError(f"Input must have {seq_len} items.")
50
+
51
+ except Exception as e:
52
+ # 오류 발생 시 에러 메시지를 반환합니다.
53
+ return {"error": f"Invalid input format: {str(e)}"}, 400
54
+
55
+ # 2. 데이터를 모델 입력 형식(Tensor)으로 변환
56
+ input_array = np.array(input_list).reshape(-1, 1)
57
+ scaled_input = self.scaler.transform(input_array)
58
+ input_tensor = torch.from_numpy(scaled_input).float().unsqueeze(0).to(self.model.device)
59
+
60
+ # 3. 모델 예측 실행
61
+ with torch.no_grad():
62
+ # TimeXer 모델의 forward 함수에 필요한 모든 인자를 전달해야 합니다.
63
+ # 예: outputs = self.model(batch_x=input_tensor, batch_x_mark=...)
64
+ # 이 부분은 모델의 실제 코드를 보고 채워야 합니다.
65
+ # 여기서는 input_tensor만 필요하다고 가정합니다.
66
+ outputs = self.model(input_tensor)
67
+
68
+ # 4. 예측 결과를 후처리하고 원래 스케일로 복원
69
+ prediction_scaled = outputs.detach().cpu().numpy().squeeze()
70
+ prediction = self.scaler.inverse_transform(prediction_scaled.reshape(-1, 1))
71
+
72
+ # 5. 최종 결과를 리스트 형태로 반환
73
+ # Hugging Face 핸들러는 보통 텍스트나 바이트를 반환해야 합니다.
74
+ # 결과를 JSON 문자열로 만들어 반환하는 것이 일반적입니다.
75
+ import json
76
+ result_str = json.dumps({"prediction": prediction.flatten().tolist()})
77
+
78
+ return [result_str]