| import os
|
| import json
|
| import torch
|
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
|
|
|
|
def model_fn(model_dir):
    """Load the classifier and its companions; SageMaker calls this to load the model.

    Args:
        model_dir (str): Directory that holds the model files
            (``config.json``, weights, tokenizer files and, optionally,
            ``label_map.json``).

    Returns:
        dict: Bundle with the keys ``"model"``, ``"tokenizer"``, ``"config"``,
        ``"device"`` and ``"label_map"``. ``"label_map"`` is an empty dict
        when no ``label_map.json`` exists in ``model_dir``.
    """
    # Switch off the tokenizers library's parallelism via its environment flag.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    model_config = AutoConfig.from_pretrained(os.path.join(model_dir, "config.json"))

    # Probe CUDA once and reuse the answer for logging, dtype and placement.
    cuda_available = torch.cuda.is_available()
    device_name = "cuda" if cuda_available else "cpu"

    print(f"Loading model from {model_dir}")
    print(f"Device: {device_name}")

    # Optional mapping from stringified class index to a readable label.
    labels_file = os.path.join(model_dir, "label_map.json")
    if not os.path.exists(labels_file):
        id_to_label = {}
        print("No label map found. Using numeric indices as labels.")
    else:
        with open(labels_file, "r", encoding="utf-8") as handle:
            id_to_label = json.load(handle)
        print(f"Loaded label map from {labels_file}")

    # Half precision on GPU, full precision on CPU.
    weights_dtype = torch.float16 if cuda_available else torch.float32
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        config=model_config,
        torch_dtype=weights_dtype,
    )

    target_device = torch.device(device_name)
    classifier = classifier.to(target_device)
    classifier.eval()

    text_tokenizer = AutoTokenizer.from_pretrained(model_dir)

    bundle = {
        "model": classifier,
        "tokenizer": text_tokenizer,
        "config": model_config,
        "device": target_device,
        "label_map": id_to_label,
    }
    return bundle
|
|
|
|
|
def input_fn(request_body, request_content_type):
    """Deserialize a request payload; SageMaker calls this to process request data.

    Args:
        request_body: Raw request payload (``str``, ``bytes`` or
            ``bytearray``).
        request_content_type (str): Content type of the request. Letter case
            and media-type parameters such as ``; charset=utf-8`` are ignored
            when matching.

    Returns:
        For ``application/json``: the decoded JSON value, except that a bare
        JSON string is wrapped as ``{"text": <string>}``.
        For ``text/plain``: ``{"text": <body>}``, with byte payloads decoded
        as UTF-8.

    Raises:
        ValueError: If the content type is neither ``application/json`` nor
            ``text/plain``.
    """
    # Compare only the bare media type so "application/json; charset=utf-8"
    # or "Text/Plain" are accepted as well.
    media_type = str(request_content_type or "").split(";", 1)[0].strip().lower()

    if media_type == "application/json":
        payload = json.loads(request_body)
        # A bare JSON string is shorthand for {"text": "..."}.
        return {"text": payload} if isinstance(payload, str) else payload

    if media_type == "text/plain":
        # The body may arrive as bytes or as an already-decoded str; only
        # bytes-like payloads need decoding.
        if isinstance(request_body, (bytes, bytearray)):
            request_body = request_body.decode("utf-8")
        return {"text": request_body}

    # Message is Korean for "Unsupported content type: <type>".
    raise ValueError(f"지원되지 않는 콘텐츠 타입: {request_content_type}")
|
|
|
|
|
def _attach_label(result, label_map, class_index):
    """Add ``result["label"]`` when ``label_map`` has an entry for ``class_index``."""
    key = str(class_index)  # label-map keys are stringified class indices
    if label_map and key in label_map:
        result["label"] = label_map[key]


def predict_fn(input_data, model_dict):
    """Run the classifier on the request; SageMaker calls this to predict.

    Only the first sequence of the tokenized batch is reported.

    Args:
        input_data (dict): Processed request. Must contain ``"text"``; may
            contain the tokenizer options ``"max_length"`` (default 512),
            ``"padding"`` (default ``"max_length"``) and ``"truncation"``
            (default ``True``).
        model_dict (dict): Bundle returned by ``model_fn``; the keys
            ``"model"``, ``"tokenizer"``, ``"device"`` and ``"label_map"``
            are read.

    Returns:
        dict: For a two-class model: ``"prediction"`` (0 or 1),
        ``"positive_probability"`` and ``"negative_probability"``.
        Otherwise: ``"prediction"`` (arg-max class index) and
        ``"probabilities"`` (list of floats), plus ``"label_probabilities"``
        when a label map is available. In both cases ``"label"`` is added
        when the label map has an entry for the predicted index.

    Raises:
        ValueError: If ``input_data`` is not a dict or has no ``"text"`` key.
    """
    model = model_dict["model"]
    tokenizer = model_dict["tokenizer"]
    device = model_dict["device"]
    label_map = model_dict["label_map"]

    # Reject non-dict payloads (e.g. a JSON array) with the same error as a
    # missing field instead of failing later with TypeError/AttributeError.
    if not isinstance(input_data, dict) or "text" not in input_data:
        # Message is Korean for "The input data has no 'text' field".
        raise ValueError("입력 데이터에 'text' 필드가 없습니다")
    text = input_data["text"]

    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=input_data.get("padding", "max_length"),
        truncation=input_data.get("truncation", True),
        max_length=input_data.get("max_length", 512),
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits
        # The model may run in half precision; do the softmax in float32 so
        # the reported probabilities are not rounded to float16 resolution.
        probabilities = torch.softmax(logits.float(), dim=1)

    # Tensor.tolist() already yields plain Python floats from any device.
    row = probabilities[0].tolist()

    if len(row) == 2:
        negative_prob, positive_prob = row
        prediction = 1 if positive_prob > 0.5 else 0
        result = {
            "prediction": prediction,
            "positive_probability": positive_prob,
            "negative_probability": negative_prob,
        }
        _attach_label(result, label_map, prediction)
        return result

    prediction = torch.argmax(probabilities, dim=1)[0].item()
    result = {
        "prediction": prediction,
        "probabilities": row,
    }
    _attach_label(result, label_map, prediction)
    if label_map:
        # Classes without a label-map entry fall back to their index string.
        result["label_probabilities"] = {
            label_map.get(str(idx), str(idx)): prob
            for idx, prob in enumerate(row)
        }
    return result
|
|
|
|
|
def output_fn(prediction, response_content_type):
    """Serialize a prediction; SageMaker calls this to build the response.

    Args:
        prediction: Result returned by ``predict_fn``.
        response_content_type (str): Requested response content type. Letter
            case and media-type parameters such as ``; charset=utf-8`` are
            ignored when matching.

    Returns:
        str: ``prediction`` as a JSON string. Non-ASCII characters are
        emitted as-is rather than as ``\\uXXXX`` escapes.

    Raises:
        ValueError: If the requested content type is not ``application/json``.
    """
    # Compare only the bare media type so "application/json; charset=utf-8"
    # is accepted as well.
    media_type = str(response_content_type or "").split(";", 1)[0].strip().lower()

    if media_type != "application/json":
        # Message is Korean for "Unsupported content type: <type>".
        raise ValueError(f"지원되지 않는 콘텐츠 타입: {response_content_type}")

    return json.dumps(prediction, ensure_ascii=False)