# simcse_chat / handler.py
# Uploaded via huggingface_hub (commit db9b4f9)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
def EndpointHandler(path=""):
    """Factory for the handler used by Hugging Face Inference Endpoints.

    Returns a callable object that loads the TinyLlama base model plus a
    LoRA (PEFT) adapter and answers payloads of the form
    ``{"inputs": "<user text>"}`` with ``{"response": "<assistant text>"}``.
    """

    class Handler:
        def __init__(self, path: str = ""):
            # Prefer GPU when available; fall back to CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            # Adapter weights: an explicit path wins, otherwise the bundled default.
            self.adapter_path = path or "./tinyllama-qa-news"

            # Load the base model in fp16 and attach the LoRA adapter on top.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16,
            )
            self.model = PeftModel.from_pretrained(base_model, self.adapter_path)
            self.model.to(self.device)
            self.model.eval()

            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
            # The base model ships without a pad token; reuse EOS so that
            # tokenization with padding and generate() both work.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "right"

        def generate_response(self, user_input: str) -> str:
            """Generate one assistant reply for a single user message."""
            messages = [
                {
                    "role": "system",
                    "content": "๋‹น์‹ ์€ ๊ณต๊ฐ์ ์ด๊ณ  ์ดํ•ด์‹ฌ ๊นŠ์€ ๊ฐ์ • ์ƒ๋‹ด ์ฑ—๋ด‡์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ์ž์˜ ๊ฐ์ •๊ณผ ์ƒ๊ฐ์„ ์ดํ•ดํ•˜๊ณ  ๊ณต๊ฐํ•˜๋ฉฐ, ์ ์ ˆํ•œ ์กฐ์–ธ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค."
                },
                {"role": "user", "content": user_input}
            ]
            # tokenize=False returns the rendered prompt string; the
            # return_tensors kwarg is meaningless in that mode and was dropped.
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                # Beam-search-only knobs (num_beams, early_stopping,
                # length_penalty, num_return_sequences) were removed: they were
                # all at their defaults and early_stopping=False with
                # num_beams=1 emits a transformers warning on every call.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.8,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )

            # BUGFIX: decode only the newly generated tokens instead of
            # decoding the whole sequence and string-splitting on the
            # "<|assistant|>" marker / the raw prompt text. The old split
            # silently returned the entire transcript (prompt included)
            # whenever neither delimiter survived decoding verbatim; slicing
            # by input length is exact regardless of the template.
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            ).strip()

            # Never return an empty reply; fall back to a gentle prompt.
            if not response:
                response = "๊ฐ์ •์„ ์ดํ•ดํ•ด ์ฃผ์…”์„œ ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค. ๋” ์ด์•ผ๊ธฐํ•˜๊ณ  ์‹ถ์œผ์‹  ๊ฒŒ ์žˆ์œผ์‹ ๊ฐ€์š”?"
            return response

        def __call__(self, data):
            # Inference Endpoints contract: the JSON payload arrives as a dict
            # with the user text under "inputs".
            if not data or "inputs" not in data:
                return {"error": "์ž…๋ ฅ ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค."}
            user_input = data["inputs"]
            response = self.generate_response(user_input)
            return {"response": response}

    return Handler(path)