# server.py
import os
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ── 1) Create the FastAPI application.
# NOTE(review): the original Korean metadata strings were mojibake-corrupted;
# the `description` literal was even split across two source lines (a syntax
# error). They are reconstructed here in English.
app = FastAPI(
    title="AGaRiCleaner Toxicity Detector (FastAPI)",
    description="FastAPI-based Korean toxicity detection model server",
    version="1.0.0",
)
# ── 2) Request schema (Pydantic model).
class TextsIn(BaseModel):
    """Request body for POST /predict: a list of sentences to classify."""
    # Expected JSON shape: { "data": ["sentence 1", "sentence 2", ...] }
    data: List[str]
# ── 3) Model directory path (on a HuggingFace Space this is /app/detector).
MODEL_DIR = "./detector"

# ─── Debug: dump the detector directory contents ──────────────────────────
# NOTE(review): the original Korean comments/messages were mojibake-corrupted
# and are rewritten in English.
print("===== DETECTOR directory contents =====")
for fname in sorted(os.listdir(MODEL_DIR)):
    fpath = os.path.join(MODEL_DIR, fname)
    try:
        size = os.path.getsize(fpath)
    except OSError:  # narrow catch: only filesystem errors are expected here
        size = None
    preview = ""
    # For JSON files, read only the first line as a preview.
    if fname.endswith(".json"):
        try:
            with open(fpath, "r", encoding="utf-8") as f:
                preview = f.readline().strip()
        except (OSError, UnicodeDecodeError):
            preview = "<read failed>"
    # BUG FIX: the original `size={size:<8}` raised TypeError when size was
    # None (None has no format spec support); `!s` stringifies first and is
    # identical output for ints.
    print(f" - {fname:<25} size={size!s:<8} preview={preview[:60]!r}")
print("========================================\n")
# ─── End of debug code ────────────────────────────────────────────────────
# (From here on, the real tokenizer/model load.)
# ── 4) Device selection: prefer CUDA, then Apple MPS, else CPU.
# NOTE(review): the original only probed MPS (Mac); a CUDA probe is added so
# GPU deployments are picked up automatically, and the mps attribute is
# guarded for older torch builds that lack it. CPU behavior is unchanged.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"> Inference device: {device}")
# ── 5) Load the tokenizer and model from MODEL_DIR.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.to(device)
    model.eval()  # inference mode: disables dropout/batch-norm updates
    print("Model and tokenizer loaded.")
except Exception as e:
    # Log and propagate: the server must not start with a half-loaded model.
    print("Model load failed:", e)
    # BUG FIX: bare `raise` preserves the original traceback; the original
    # `raise e` re-raised from this frame and lost context.
    raise
# ── 6) Toxicity detection helper.
# NOTE(review): the original section comment was split mid-character by
# encoding corruption, leaving a bare non-comment line (a syntax error);
# repaired here.
def detect_toxic(texts: List[str]) -> List[dict]:
    """Score each text for toxicity with the loaded sequence classifier.

    Args:
        texts: sentences to classify (batched through the tokenizer).

    Returns:
        One dict per input, in order: ``{"text": str, "label": 0|1,
        "score": float}`` where ``score`` is the class-1 probability rounded
        to 6 places and ``label`` is 1 when that probability is >= 0.5.
    """
    encoding = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=128,  # inputs longer than 128 tokens are truncated
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    probs = torch.softmax(outputs.logits, dim=-1).cpu().tolist()
    # assumes label index 1 = "toxic" — TODO confirm against the model config
    return [
        {
            "text": text,
            "label": 1 if p[1] >= 0.5 else 0,
            "score": round(p[1], 6),
        }
        for text, p in zip(texts, probs)
    ]
# ── 7) POST /predict endpoint.
# NOTE(review): the original Korean `summary=` string literal was split
# across four source lines by encoding corruption (a syntax error);
# reconstructed here in English, along with the corrupted detail messages.
@app.post(
    "/predict",
    summary="Classify a list of texts, returning (label, score) for each",
)
async def predict_endpoint(payload: TextsIn):
    """Run toxicity detection over ``payload.data``.

    Raises:
        HTTPException 400: when ``data`` is not a non-empty list.
        HTTPException 500: when model inference fails.
    """
    texts = payload.data
    # Pydantic already enforces List[str]; still reject an empty list.
    if not isinstance(texts, list) or not texts:
        raise HTTPException(
            status_code=400,
            detail="'data' must contain at least one string.",
        )
    try:
        return detect_toxic(texts)
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(
            status_code=500, detail=f"Error during model inference: {e}"
        ) from e