import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    Loads the TinyLlama base model, applies the fine-tuned PEFT (LoRA)
    adapter, and serves chat-style generation requests.
    """

    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.adapter_path = path or "./tinyllama-qa-news"

        # Load the base model in float16 on GPU; fall back to float32 on CPU,
        # where half precision is slow and some ops are unsupported.
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            torch_dtype=dtype
        )
        self.model = PeftModel.from_pretrained(base_model, self.adapter_path)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

    def generate_response(self, user_input: str) -> str:
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an empathetic and understanding emotional-support "
                    "chatbot. You understand and empathize with the user's "
                    "feelings and thoughts, and offer appropriate advice."
                )
            },
            {"role": "user", "content": user_input}
        ]

        # Render the chat template to a plain string; tokenization happens below.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        # This is more robust than string-splitting the full decoded output
        # on a template marker such as "<|assistant|>".
        prompt_len = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        if not response:
            response = (
                "Thank you for sharing that with me. "
                "Is there anything else you'd like to talk about?"
            )

        return response

    def __call__(self, data):
        if not data or "inputs" not in data:
            return {"error": "No input data provided."}

        return {"response": self.generate_response(data["inputs"])}
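
# --- Local smoke test --------------------------------------------------------
# A minimal sketch for exercising the handler outside of Inference Endpoints,
# assuming the adapter weights are available at the default path above (the
# sample message is illustrative only). The platform itself never runs this
# block; it imports the module, instantiates EndpointHandler(path) once at
# startup, and invokes it per request with a payload shaped like
# {"inputs": "<user message>"}.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "I've been feeling overwhelmed at work lately."})
    print(result)  # e.g. {"response": "..."}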