File size: 3,135 Bytes
d39fb58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
def EndpointHandler(path=""):
    """
    Factory for a Hugging Face Inference Endpoints handler.

    Builds and returns a ``Handler`` instance that loads the TinyLlama
    chat base model, attaches a PEFT (LoRA) adapter, and serves
    single-turn chat generation.

    Args:
        path: Directory containing the fine-tuned adapter weights.
              Falls back to ``"./tinyllama-qa-news"`` when empty.

    Returns:
        A callable handler; invoking it with ``{"inputs": <str>}``
        yields ``{"response": <str>}``, or ``{"error": <str>}`` when
        the payload is missing the ``"inputs"`` key.
    """

    class Handler:
        def __init__(self, path=""):
            # Prefer GPU when available; otherwise run on CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            self.adapter_path = path or "./tinyllama-qa-news"

            # Load the base model in fp16, then layer the fine-tuned
            # adapter weights on top of it.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16,
            )
            self.model = PeftModel.from_pretrained(base_model, self.adapter_path)
            self.model.to(self.device)
            self.model.eval()

            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
            # The tokenizer ships without a dedicated pad token; reuse EOS
            # so padded batches are accepted by generate().
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "right"

        def generate_response(self, user_input: str) -> str:
            """Generate one chat reply (Korean counseling persona) for *user_input*."""
            messages = [
                {
                    "role": "system",
                    # NOTE(review): Korean string reconstructed from a
                    # line-mangled source — confirm wording against the
                    # original file.
                    "content": (
                        "당신은 공감적이고 이해심 깊은 감정 상담 챗봇입니다. "
                        "사용자의 감정과 생각을 이해하고 공감하며, 적절한 조언을 제공합니다."
                    ),
                },
                {"role": "user", "content": user_input},
            ]
            # tokenize=False returns the rendered prompt as a plain string.
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                # Sampling-based decoding; beam-search-only knobs
                # (num_beams, length_penalty, early_stopping) were
                # no-op defaults and have been dropped.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.8,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )

            decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Strip the echoed prompt: prefer the chat-template assistant
            # marker; fall back to splitting on the raw prompt text.
            if "<|assistant|>" in decoded:
                response = decoded.split("<|assistant|>")[-1].strip()
            else:
                response = decoded.split(prompt)[-1].strip()

            if not response:
                # Graceful fallback when the model emits nothing usable.
                response = "감정을 이해해 주셔서 감사합니다. 더 이야기하고 싶으신 게 있으신가요?"
            return response

        def __call__(self, data):
            """Endpoint entry point: ``{"inputs": str} -> {"response": str}``."""
            if not data or "inputs" not in data:
                return {"error": "입력 데이터가 없습니다."}
            user_input = data["inputs"]
            return {"response": self.generate_response(user_input)}

    return Handler(path)