eunyuOffice committed on
Commit
db9b4f9
ยท
verified ยท
1 Parent(s): 6fdcbc8

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +82 -0
handler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+
5
def EndpointHandler(path=""):
    """
    Factory for the Hugging Face Inference Endpoints handler.

    Builds and returns a callable ``Handler`` instance that loads the
    TinyLlama-1.1B-Chat base model in fp16, attaches a PEFT (LoRA)
    adapter on top, and serves single-turn chat completions.

    Args:
        path: Directory containing the PEFT adapter weights. When empty,
            falls back to the bundled "./tinyllama-qa-news" adapter.

    Returns:
        A callable object; calling it with ``{"inputs": <user text>}``
        returns ``{"response": <generated text>}``, or
        ``{"error": ...}`` when the input payload is missing.
    """

    class Handler:
        def __init__(self, path=""):
            # Prefer GPU when available; fp16 weights keep memory low.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            self.adapter_path = path or "./tinyllama-qa-news"

            # Load the base model first, then layer the fine-tuned
            # adapter on top of it.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16,
            )
            self.model = PeftModel.from_pretrained(base_model, self.adapter_path)
            self.model.to(self.device)
            self.model.eval()

            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
            # TinyLlama ships without a pad token; reuse EOS so the
            # tokenizer (and generate) can pad if needed.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            # NOTE(review): right-padding is fine for the single,
            # unpadded prompt used here; batched generation would need
            # padding_side="left".
            self.tokenizer.padding_side = "right"

        def generate_response(self, user_input: str) -> str:
            """Generate one assistant reply for *user_input* (sampled, max 256 new tokens)."""
            messages = [
                {
                    "role": "system",
                    "content": "당신은 공감적이고 이해심 깊은 감정 상담 챗봇입니다. 사용자의 감정과 생각을 이해하고 공감하며, 적절한 조언을 제공합니다."
                },
                {"role": "user", "content": user_input}
            ]

            # Render the chat template to a plain string; we tokenize it
            # ourselves below. (Dropped the no-op return_tensors=None.)
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Remember where the prompt ends so we can slice it off the
            # output *in token space* — robust, unlike re-splitting the
            # decoded string (skip_special_tokens alters the text, so the
            # raw prompt may not literally reappear in the decode).
            prompt_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.8,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    # Removed beam-search-only knobs (num_beams=1,
                    # length_penalty, early_stopping) — they are no-ops
                    # under sampling and only trigger warnings.
                )

            # Decode only the newly generated tokens.
            response = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            ).strip()

            # Fall back to a friendly default when the model emits
            # nothing usable (e.g. immediate EOS).
            if not response:
                response = "감정을 이해해 주셔서 감사합니다. 더 이야기하고 싶으신 게 있으신가요?"

            return response

        def __call__(self, data):
            """Endpoint entry point; expects a payload like ``{"inputs": <user text>}``."""
            if not data or "inputs" not in data:
                return {"error": "입력 데이터가 없습니다."}

            user_input = data["inputs"]
            return {"response": self.generate_response(user_input)}

    return Handler(path)