hiddenFront's picture
Update app.py
3dd80ec verified
raw
history blame
3.1 kB
from fastapi import FastAPI, Request
from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil
app = FastAPI()
device = torch.device("cpu")  # all inference in this service runs on the CPU

# Load the pickled category mapping. The rest of the file uses len(category)
# as the number of labels and the order of its keys to map a class index back
# to a label name.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so category.pkl must only ever come from a trusted source.
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("✅ category.pkl 로드 성공.")
except FileNotFoundError:
    # Without the label mapping the service cannot classify anything: abort.
    print("❌ Error: category.pkl 파일을 찾을 수 없습니다.")
    sys.exit(1)

# Load the tokenizer that ships with the KoBERT base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ 토크나이저 로드 성공.")
class CustomClassifier(torch.nn.Module):
    """KoBERT encoder followed by a single linear classification head.

    The head produces one logit per entry of the module-level ``category``
    mapping.

    NOTE(review): nothing in the visible part of this file instantiates this
    class (the served model is built with ``BertForSequenceClassification``);
    confirm whether it is still needed.
    """

    def __init__(self):
        super().__init__()
        # The layout must mirror the architecture the model was originally
        # defined with, otherwise saved weights will not line up.
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        # Size the head from the encoder's own config instead of hard-coding
        # 768, so the two can never drift apart if the checkpoint changes.
        self.classifier = torch.nn.Linear(
            self.bert.config.hidden_size, len(category)
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return the classification logits for a batch of encoded inputs.

        Args:
            input_ids: Token id tensor passed straight to the encoder.
            attention_mask: Optional attention mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.

        Returns:
            The output of the linear head applied to the encoder's pooled
            output.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Second element of the encoder output: BERT's pooled representation
        # derived from the [CLS] token.
        pooled_output = outputs[1]
        return self.classifier(pooled_output)
# Hugging Face Hub location of the fine-tuned classifier weights.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Record this process's resident memory before the model is fetched.
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)  # bytes -> MiB
print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")

# Download the weights and build the classifier around them. Any failure here
# is fatal: the service is useless without a model, so log and exit.
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ 모델 파일 다운로드 성공: {model_path}")
    # NOTE(review): torch.load unpickles the downloaded file. If the file is a
    # plain tensor state dict, passing weights_only=True would avoid executing
    # arbitrary pickled code - confirm the file format before enabling it.
    state_dict = torch.load(model_path, map_location=device)
    model = BertForSequenceClassification.from_pretrained(
        "skt/kobert-base-v1",
        num_labels=len(category),
        state_dict=state_dict,
    )
    model.to(device)
    model.eval()  # inference mode: disables dropout
    print("✅ 모델 로드 및 준비 완료.")
except Exception as e:
    print(f"❌ Error: 모델 로드 중 오류 발생: {e}")
    sys.exit(1)
@app.get("/")
def root(request: Request):
    """Liveness endpoint that also echoes the caller's network address.

    Args:
        request: The incoming request; only its ``client`` address is read.

    Returns:
        A dict with a status ``message`` plus ``client_ip`` and
        ``client_port``. Both address fields are ``None`` when the server
        does not expose a client address for the connection.
    """
    # request.client is optional: it can be None (for example behind some
    # transports or test clients), so guard it instead of raising.
    client = request.client
    client_host = client.host if client is not None else None
    client_port = client.port if client is not None else None
    return {
        "message": "Text Classification API is running!",
        "client_ip": client_host,
        "client_port": client_port,
    }
# Prediction API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the ``text`` field of a JSON request body.

    Args:
        request: Incoming request whose body is expected to be a JSON object
            with a non-empty string under the ``"text"`` key.

    Returns:
        ``{"text": <input>, "classification": <label>}`` on success, or
        ``{"error": "No text provided", "classification": "null"}`` when the
        body is not valid JSON, is not a JSON object, or carries no usable
        text.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON (json.JSONDecodeError is a ValueError): treat it the
        # same as a missing payload instead of failing with a 500.
        data = None
    print("request data", data)

    # Only a JSON object can carry a "text" field; lists, numbers, etc. cannot.
    text = data.get("text") if isinstance(data, dict) else None
    if not isinstance(text, str) or not text:
        return {"error": "No text provided", "classification": "null"}

    # Pad/truncate to a fixed 64-token window and return PyTorch tensors.
    encoded = tokenizer(
        text,
        max_length=64,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # NOTE(review): this forward pass is synchronous CPU work inside an async
    # handler, so it blocks the event loop while it runs - confirm that is
    # acceptable for the expected load.
    with torch.no_grad():
        logits = model(**encoded).logits
    # Softmax is monotonic, so the argmax of the raw logits is the same class.
    predicted = torch.argmax(logits, dim=1).item()

    # Presumably the key order of `category` matches the class indices used
    # during training - verify against the training code.
    label = list(category)[predicted]
    return {"text": text, "classification": label}