File size: 3,874 Bytes
1e36ed8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import os
import torch
import numpy as np
import joblib
# Hugging Face๊ฐ ์ ๊ณตํ๋ ๊ธฐ๋ณธ ํธ๋ค๋ฌ ํด๋์ค๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
from text_generation_server.models.custom_handler import BaseHandler
class TidePredictionHandler(BaseHandler):
"""
TimeXer ๋ชจ๋ธ์ ์ํ ์ปค์คํ
ํธ๋ค๋ฌ
"""
def __init__(self, model, tokenizer):
# ์ด ํจ์๋ ์๋ฒ๊ฐ ์์๋ ๋ ๋ฑ ํ ๋ฒ ์คํ๋ฉ๋๋ค.
# ๋ชจ๋ธ ์ด๊ธฐํ, ์ค์ผ์ผ๋ฌ ๋ก๋ฉ ๋ฑ ์ค๋น ์์
์ ์ฌ๊ธฐ์ ํฉ๋๋ค.
super().__init__(model, tokenizer)
# 1. ์ค์ผ์ผ๋ฌ๋ฅผ ๋ถ๋ฌ์์ self.scaler์ ์ ์ฅํฉ๋๋ค.
# ๊ฒฝ๋ก๋ ์ ์ฅ์ ๋ด์ ์ค์ ํ์ผ ์์น์ ๊ฐ์์ผ ํฉ๋๋ค.
scaler_path = os.path.join(os.getcwd(), 'checkpoints', 'scaler.gz')
self.scaler = joblib.load(scaler_path)
# 2. ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ์ค์ ํฉ๋๋ค.
self.model.eval()
# 3. ๋ชจ๋ธ์ ์ค์ ๊ฐ๋ค์ self.model.args ์ฒ๋ผ ์ ๊ทผํ ์ ์๋๋ก ์ ์ฅํด๋๋ฉด ํธ๋ฆฌํฉ๋๋ค.
# ์ด ๋ถ๋ถ์ TimeXer ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ๋ฐ๋ผ ํ์ ์์ ์๋ ์์ต๋๋ค.
# ์: self.seq_len = self.model.seq_len
def __call__(self, inputs, **kwargs):
# ์ด ํจ์๋ API ์์ธก ์์ฒญ์ด ์ฌ ๋๋ง๋ค ์คํ๋ฉ๋๋ค.
# ์ค์ ์์ธก ๋ก์ง์ด ๋ค์ด๊ฐ๋ ๋ถ๋ถ์
๋๋ค.
# 1. ์
๋ ฅ ๋ฐ์ดํฐ ํ์ฑ
# inputs๋ ๋ณดํต ๋ฆฌ์คํธ ํํ์ ํ
์คํธ ๋๋ ๋ฐ์ดํธ ๋ฐ์ดํฐ๋ก ๋ค์ด์ต๋๋ค.
# JSON ํ์์ผ๋ก ์
๋ ฅ์ ๋ฐ์ผ๋ ค๋ฉด ์ถ๊ฐ์ ์ธ ์ฒ๋ฆฌ๊ฐ ํ์ํ ์ ์์ต๋๋ค.
# ์ฌ๊ธฐ์๋ ๊ฐ๋จํ inputs๊ฐ ์ซ์ ๋ฆฌ์คํธ ๋ฌธ์์ด์ด๋ผ๊ณ ๊ฐ์ ํฉ๋๋ค.
# ์: "500.1, 502.3, ..., 498.7"
# ๋ฌธ์์ด์ ์ซ์ ๋ฆฌ์คํธ๋ก ๋ณํ
try:
# ์
๋ ฅ ๋ฐ์ดํฐ๋ฅผ ํ์ฑํ๋ ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ JSON์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค.
# ์: `json.loads(inputs[0])`
# ์ฌ๊ธฐ์๋ ๊ฐ๋จํ ์์๋ฅผ ์ํด split์ ์ฌ์ฉํฉ๋๋ค.
input_list = [float(i) for i in inputs[0].split(',')]
seq_len = 144 # ์ด ๊ฐ์ ์ค์ ๋ชจ๋ธ์ ์
๋ ฅ ๊ธธ์ด์ ์ผ์นํด์ผ ํฉ๋๋ค.
if len(input_list) != seq_len:
raise ValueError(f"Input must have {seq_len} items.")
except Exception as e:
# ์ค๋ฅ ๋ฐ์ ์ ์๋ฌ ๋ฉ์์ง๋ฅผ ๋ฐํํฉ๋๋ค.
return {"error": f"Invalid input format: {str(e)}"}, 400
# 2. ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์(Tensor)์ผ๋ก ๋ณํ
input_array = np.array(input_list).reshape(-1, 1)
scaled_input = self.scaler.transform(input_array)
input_tensor = torch.from_numpy(scaled_input).float().unsqueeze(0).to(self.model.device)
# 3. ๋ชจ๋ธ ์์ธก ์คํ
with torch.no_grad():
# TimeXer ๋ชจ๋ธ์ forward ํจ์์ ํ์ํ ๋ชจ๋ ์ธ์๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค.
# ์: outputs = self.model(batch_x=input_tensor, batch_x_mark=...)
# ์ด ๋ถ๋ถ์ ๋ชจ๋ธ์ ์ค์ ์ฝ๋๋ฅผ ๋ณด๊ณ ์ฑ์์ผ ํฉ๋๋ค.
# ์ฌ๊ธฐ์๋ input_tensor๋ง ํ์ํ๋ค๊ณ ๊ฐ์ ํฉ๋๋ค.
outputs = self.model(input_tensor)
# 4. ์์ธก ๊ฒฐ๊ณผ๋ฅผ ํ์ฒ๋ฆฌํ๊ณ ์๋ ์ค์ผ์ผ๋ก ๋ณต์
prediction_scaled = outputs.detach().cpu().numpy().squeeze()
prediction = self.scaler.inverse_transform(prediction_scaled.reshape(-1, 1))
# 5. ์ต์ข
๊ฒฐ๊ณผ๋ฅผ ๋ฆฌ์คํธ ํํ๋ก ๋ฐํ
# Hugging Face ํธ๋ค๋ฌ๋ ๋ณดํต ํ
์คํธ๋ ๋ฐ์ดํธ๋ฅผ ๋ฐํํด์ผ ํฉ๋๋ค.
# ๊ฒฐ๊ณผ๋ฅผ JSON ๋ฌธ์์ด๋ก ๋ง๋ค์ด ๋ฐํํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์
๋๋ค.
import json
result_str = json.dumps({"prediction": prediction.flatten().tolist()})
return [result_str] |