# HF Space page metadata (scrape residue, preserved as comments so the file parses):
# hiddenFront — "Update app.py" — commit 6ba018e (verified) — raw / history / blame — 3.09 kB
import os
import pickle
import sys

import psutil
import torch
from fastapi import FastAPI, Request
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, BertForSequenceClassification, BertModel
# FastAPI application instance; inference is pinned to CPU.
app = FastAPI()
device = torch.device("cpu")

# Load the label mapping produced at training time.
# NOTE(review): presumably maps label name -> class index; confirm against training code.
try:
    with open("category.pkl", "rb") as fh:
        category = pickle.load(fh)
    print("โœ… category.pkl ๋กœ๋“œ ์„ฑ๊ณต.")
except FileNotFoundError:
    # Without the label map the service cannot decode predictions — abort startup.
    print("โŒ Error: category.pkl ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    sys.exit(1)

# Tokenizer matching the KoBERT backbone used by the classifier.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
class CustomClassifier(torch.nn.Module):
    """KoBERT backbone with a linear classification head.

    The structure must mirror the one used at training time so that the
    saved state_dict (downloaded from the Hub below) loads cleanly.

    NOTE(review): the original referenced ``BertModel`` without importing it
    (only ``BertForSequenceClassification`` was imported), which raised a
    NameError at construction time; ``BertModel`` is now imported at the top
    of the file.
    """

    def __init__(self, num_labels=None):
        """Build the backbone and head.

        Args:
            num_labels: size of the output layer. Defaults to
                ``len(category)`` (the module-level label map) for backward
                compatibility; pass explicitly for other label sets.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        # 768 is BERT-base's hidden size.
        self.classifier = torch.nn.Linear(
            768, len(category) if num_labels is None else num_labels
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return raw (unnormalized) logits of shape (batch, num_labels)."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = outputs[1]  # pooled [CLS] representation
        return self.classifier(pooled_output)
# Instantiate the classifier skeleton; the trained weights are loaded further
# below, once the checkpoint has been downloaded from the Hub.
# BUGFIX: the original also called torch.load(model_path)/load_state_dict/eval
# here, but model_path is only assigned inside the download try-block below,
# so this line raised NameError at import time. The load inside the try-block
# is the working (and previously duplicated) one, so only the instantiation
# is kept here.
model = CustomClassifier()
# Hub coordinates of the fine-tuned checkpoint.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Snapshot this process's resident-set size (in MiB) before the download,
# so memory growth can be reported at each stage below.
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 ** 2)
print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before:.2f} MB")
# ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋‹ค์šด๋กœ๋“œ
try:
model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
print(f"โœ… ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์„ฑ๊ณต: {model_path}")
mem_after_dl = process.memory_info().rss / (1024 * 1024)
print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_dl:.2f} MB")
# state_dict ๋กœ๋“œ
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
mem_after_load = process.memory_info().rss / (1024 * 1024)
print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_load:.2f} MB")
print("โœ… ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ์ค€๋น„ ์™„๋ฃŒ.")
except Exception as e:
print(f"โŒ Error: ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
sys.exit(1)
# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the "text" field of a JSON request body.

    Returns {"text": ..., "classification": <label>} on success, or an
    error payload when "text" is missing/empty.
    """
    data = await request.json()
    text = data.get("text")
    if not text:
        return {"error": "No text provided", "classification": "null"}
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )
    with torch.no_grad():
        # BUGFIX: CustomClassifier.forward returns the raw logits Tensor, not a
        # HF ModelOutput, so the original `outputs.logits` raised AttributeError
        # on every request. Use the tensor directly.
        logits = model(**encoded)
        probs = torch.nn.functional.softmax(logits, dim=1)
        predicted = torch.argmax(probs, dim=1).item()
    # NOTE(review): assumes the insertion order of category's keys matches the
    # class-index order used at training time — TODO confirm.
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}