File size: 3,761 Bytes
f4d60e7
ec4f96f
f4d60e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5d9ecd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4d60e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# server.py
import os
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ── 1) Create the FastAPI application (metadata grouped for readability)
_APP_METADATA = {
    "title": "AGaRiCleaner Toxicity Detector (FastAPI)",
    "description": "FastAPI 기반 ν•œκ΅­μ–΄ μ•…ν”Œ 탐지 λͺ¨λΈ μ„œλ²„",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)

# ── 2) Request schema definition (Pydantic model)
class TextsIn(BaseModel):
    """Request body for POST /predict.

    JSON example: { "data": ["sentence 1", "sentence 2", ...] }
    """
    data: List[str]  # texts to classify; the endpoint rejects an empty list with a 400

# ── 3) λͺ¨λΈ 디렉터리 경둜 (Spaceμ—μ„œλŠ” /app/detector 폴더가 λœλ‹€)
MODEL_DIR = "./detector"

# ─── μ—¬κΈ°λΆ€ν„° 디버그 μ½”λ“œ ───────────────────────────────────────────────────────
print("===== DETECTOR 디렉터리 λ‚΄λΆ€ μƒνƒœ =====")
for fname in sorted(os.listdir(MODEL_DIR)):
    fpath = os.path.join(MODEL_DIR, fname)
    try:
        size = os.path.getsize(fpath)
    except Exception:
        size = None
    preview = ""
    # JSON 파일인 경우, 첫 μ€„λ§Œ μ½μ–΄μ„œ preview에 λ‹΄μ•„λ‘‘λ‹ˆλ‹€.
    if fname.endswith(".json"):
        try:
            with open(fpath, "r", encoding="utf-8") as f:
                preview = f.readline().strip()
        except Exception:
            preview = "<읽기 μ‹€νŒ¨>"
    print(f" β€’ {fname:<25} size={size:<8} preview={preview[:60]!r}")
print("========================================\n")
# ─── 디버그 μ½”λ“œ 끝 ─────────────────────────────────────────────────────────────

# (μ΄μ œλΆ€ν„° μ‹€μ œ ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ)

# ── 4) λ””λ°”μ΄μŠ€ μ„€μ • (Mac MPS 지원 μ—¬λΆ€ 확인)
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"β–Ά λͺ¨λΈ μΆ”λ‘  λ””λ°”μ΄μŠ€: {device}")

# ── 5) ν† ν¬λ‚˜μ΄μ €μ™€ λͺ¨λΈ λ‘œλ“œ
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.to(device)
    model.eval()
    print("βœ” λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ μ™„λ£Œ")
except Exception as e:
    print("βœ– λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨:", e)
    raise e

# ── 6) Toxicity detection helper
def detect_toxic(texts: List[str]) -> List[dict]:
    """Classify each text and return [{"text", "label", "score"}, ...].

    label is 1 when the positive-class (index 1) probability is >= 0.5,
    else 0; score is that probability rounded to 6 decimal places.
    """
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=128,
    )

    with torch.no_grad():
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        probs = torch.softmax(outputs.logits, dim=-1).cpu().tolist()

    return [
        {
            "text": text,
            "label": int(p[1] >= 0.5),
            "score": round(p[1], 6),
        }
        for text, p in zip(texts, probs)
    ]

# ── 7) POST /predict endpoint
@app.post("/predict", summary="ν…μŠ€νŠΈ λͺ©λ‘μ„ μž…λ ₯λ°›μ•„ μ•…ν”Œ μ—¬λΆ€(label, score) λ°˜ν™˜")
async def predict_endpoint(payload: TextsIn):
    """Run toxicity detection over payload.data and return per-text results."""
    texts = payload.data
    # Guard clause: reject a missing/empty list up front with a 400.
    if not isinstance(texts, list) or not texts:
        raise HTTPException(status_code=400, detail="β€˜data’ ν•„λ“œμ— μ΅œμ†Œ 1개 μ΄μƒμ˜ λ¬Έμžμ—΄μ΄ μžˆμ–΄μ•Ό ν•©λ‹ˆλ‹€.")
    try:
        return detect_toxic(texts)
    except Exception as e:
        # Surface inference failures as a 500 rather than an unhandled crash.
        raise HTTPException(status_code=500, detail=f"λͺ¨λΈ μΆ”λ‘  쀑 였λ₯˜ λ°œμƒ: {e}")