File size: 3,874 Bytes
1e36ed8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import json
import os

import joblib
import numpy as np
import torch

# Base handler class provided by Hugging Face.
from text_generation_server.models.custom_handler import BaseHandler

class TidePredictionHandler(BaseHandler):
    """Custom inference handler for the TimeXer tide-prediction model.

    Loads the fitted scaler once at server startup; for each request it
    parses a comma-separated series of ``SEQ_LEN`` floats, scales it, runs
    the model, inverse-scales the output, and returns it as a JSON string.
    """

    # Expected number of input points; must match the model's training
    # sequence length (was a hard-coded 144 inside __call__).
    SEQ_LEN = 144

    def __init__(self, model, tokenizer):
        # Runs exactly once when the server starts: do all one-time setup
        # (scaler loading, eval mode) here.
        super().__init__(model, tokenizer)

        # The path must match the actual file location inside the repository.
        scaler_path = os.path.join(os.getcwd(), 'checkpoints', 'scaler.gz')
        self.scaler = joblib.load(scaler_path)

        # Inference only: freeze dropout / batch-norm behavior.
        self.model.eval()

    def __call__(self, inputs, **kwargs):
        """Handle one prediction request.

        Args:
            inputs: sequence whose first element is a comma-separated string
                of ``SEQ_LEN`` numbers, e.g. ``"500.1, 502.3, ..., 498.7"``.
            **kwargs: ignored; accepted for handler-interface compatibility.

        Returns:
            ``[json_string]`` on success, where the JSON object has a single
            ``"prediction"`` key holding a list of floats.
            ``({"error": ...}, 400)`` on malformed input.
            NOTE(review): the error shape differs from the success shape —
            confirm the serving layer accepts both.
        """
        # 1. Parse the raw request. JSON input (json.loads) would be more
        #    robust; CSV split is kept to preserve the existing request format.
        try:
            series = [float(token) for token in inputs[0].split(',')]
            if len(series) != self.SEQ_LEN:
                raise ValueError(f"Input must have {self.SEQ_LEN} items.")
        except Exception as e:
            # Report the parse failure back to the client instead of crashing.
            return {"error": f"Invalid input format: {str(e)}"}, 400

        # 2. Scale and shape to (1, SEQ_LEN, 1) for the model.
        scaled_input = self.scaler.transform(np.array(series).reshape(-1, 1))
        # BUGFIX: nn.Module has no `.device` attribute (the original
        # `self.model.device` raised AttributeError). Derive the device from
        # the model's parameters, falling back to CPU for parameter-free models.
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        input_tensor = torch.from_numpy(scaled_input).float().unsqueeze(0).to(device)

        # 3. Run inference without tracking gradients.
        with torch.no_grad():
            # NOTE(review): TimeXer's forward may require additional
            # arguments (e.g. batch_x_mark); a single-tensor call is assumed
            # here — confirm against the actual model code.
            outputs = self.model(input_tensor)

        # 4. Undo the scaling to recover values in the original units.
        prediction_scaled = outputs.detach().cpu().numpy().squeeze()
        prediction = self.scaler.inverse_transform(prediction_scaled.reshape(-1, 1))

        # 5. Return a JSON string (handlers are expected to return text/bytes).
        return [json.dumps({"prediction": prediction.flatten().tolist()})]