import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
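
# Hugging Face Inference Endpoints loads this file as handler.py, instantiates
# EndpointHandler once with the model repository path, and then calls the
# instance with each request payload.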
class EndpointHandler:
    """Handler used by Inference Endpoints to serve the LoRA-tuned TinyLlama model."""

    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        # The endpoint passes its repository path; fall back to the local
        # adapter directory when constructed without one.
        self.adapter_path = path or "./tinyllama-qa-news"

        # Load the base model, then attach the fine-tuned LoRA adapter on top.
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            # fp16 on GPU to halve memory; CPU kernels for half precision are
            # poorly supported, so fall back to fp32 there.
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
        )
        self.model = PeftModel.from_pretrained(base_model, self.adapter_path)
        self.model.to(self.device)
        self.model.eval()

        # TinyLlama's tokenizer ships without a pad token, so reuse EOS.
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

    def generate_response(self, user_input: str):
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an empathetic and understanding emotional support"
                    " chatbot. You understand and empathize with the user's"
                    " feelings and thoughts, and offer appropriate advice."
                ),
            },
            {"role": "user", "content": user_input},
        ]

        # Render the conversation with the model's chat template, appending
        # the assistant header so generation continues as the assistant.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Nucleus sampling with a mild repetition penalty; pad_token_id is set
        # explicitly to silence the missing-pad-token warning.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens so the prompt is not echoed
        # back; splitting the decoded string on the prompt is brittle because
        # decoding strips special tokens that the prompt still contains.
        prompt_length = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        if not response:
            response = (
                "Thank you for sharing how you feel."
                " Is there anything else you would like to talk about?"
            )

        return response

    def __call__(self, data):
        # Inference Endpoints delivers JSON payloads of the form {"inputs": ...}.
        if not data or "inputs" not in data:
            return {"error": "No input data provided."}

        user_input = data["inputs"]
        response = self.generate_response(user_input)
        return {"response": response}
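

# Minimal local smoke test: a sketch that assumes the adapter checkpoint
# exists at ./tinyllama-qa-news. The endpoint never executes this block; the
# platform imports EndpointHandler directly and drives it with request payloads.
if __name__ == "__main__":
    handler = EndpointHandler()
    print(handler({"inputs": "I've been feeling anxious about work lately."}))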