| | |
| | import os |
| | from typing import List |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| |
|
| | |
# FastAPI application instance. The metadata below is surfaced in the
# auto-generated OpenAPI docs (/docs). NOTE(review): the Korean description
# text is mojibake in this copy of the file — kept byte-for-byte.
app = FastAPI(
    title="AGaRiCleaner Toxicity Detector (FastAPI)",
    description="FastAPI κΈ°λ° νκ΅μ΄ μν νμ§ λͺ¨λΈ μλ²",
    version="1.0.0"
)
| |
|
| | |
class TextsIn(BaseModel):
    """Request body for /predict: a batch of raw text strings to classify."""

    # One or more input strings; /predict rejects an empty list with HTTP 400.
    data: List[str]
| |
|
| | |
# Local directory holding the saved tokenizer + sequence-classification model
# (loaded below via *.from_pretrained). Path is relative to the working dir.
MODEL_DIR = "./detector"
| |
|
| | |
# Startup diagnostic: dump the model directory (name, size, and for JSON
# files a one-line preview) so deployment problems such as missing weights
# or a truncated config are visible in the server logs.
print("===== DETECTOR λλ ν°λ¦¬ λ΄λΆ μν =====")
for fname in sorted(os.listdir(MODEL_DIR)):
    fpath = os.path.join(MODEL_DIR, fname)
    try:
        size = os.path.getsize(fpath)
    except OSError:  # broken symlink, permission error, race with deletion
        size = None
    preview = ""

    if fname.endswith(".json"):
        try:
            with open(fpath, "r", encoding="utf-8") as f:
                preview = f.readline().strip()
        except Exception:
            preview = "<μ½κΈ° μ€ν¨>"
    # BUG FIX: f"{None:<8}" raises TypeError (NoneType rejects a format
    # spec), so a single unreadable file crashed startup. Cast to str —
    # identical output for ints, and None now prints as "None".
    print(f" β’ {fname:<25} size={str(size):<8} preview={preview[:60]!r}")
print("========================================\n")
| | |
| |
|
| | |
| |
|
| | |
# Pick the fastest available inference device. Generalized from the original
# MPS/CPU-only check: CUDA GPUs are preferred when present, falling back to
# Apple-Silicon MPS, then CPU (unchanged behavior on non-CUDA hosts).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"βΆ λͺ¨λΈ μΆλ‘ λλ°μ΄μ€: {device}")
| |
|
| | |
# Load the tokenizer and classifier from the local directory, move the model
# to the chosen device, and switch to eval mode (disables dropout etc.).
# Failing fast here is deliberate: the server is useless without the model.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.to(device)
    model.eval()
    print("β λͺ¨λΈ λ° ν ν¬λμ΄μ λ‘λ μλ£")
except Exception as e:
    print("β λͺ¨λΈ λ‘λ μ€ν¨:", e)
    raise  # bare raise preserves the full original traceback (was `raise e`)
| |
|
| | |
def detect_toxic(texts: List[str]) -> List[dict]:
    """Score a batch of texts with the sequence classifier.

    For every input string, returns a dict containing the original text, a
    binary label (1 when the class-1 probability is >= 0.5, else 0) and that
    probability rounded to 6 decimal places.
    """
    # Tokenize the whole batch at once, padding/truncating to a fixed length.
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    # Pure inference: no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask).logits
        probabilities = torch.softmax(logits, dim=-1).cpu().tolist()

    # Pair each input with its probability row; index 1 is treated as the
    # positive ("toxic") class.
    return [
        {
            "text": txt,
            "label": 1 if row[1] >= 0.5 else 0,
            "score": round(row[1], 6),
        }
        for txt, row in zip(texts, probabilities)
    ]
| |
|
| | |
@app.post("/predict", summary="νμ€νΈ λͺ©λ‘μ μλ ₯λ°μ μν μ¬λΆ(label, score) λ°ν")
async def predict_endpoint(payload: TextsIn):
    """Run toxicity detection on the request's text batch.

    Responds with a list of {text, label, score} dicts. Raises HTTP 400 for
    an empty batch and HTTP 500 when model inference fails.
    """
    texts = payload.data
    # Guard clause: refuse a missing or empty batch up front.
    if not isinstance(texts, list) or len(texts) == 0:
        raise HTTPException(status_code=400, detail="βdataβ νλμ μ΅μ 1κ° μ΄μμ λ¬Έμμ΄μ΄ μμ΄μΌ ν©λλ€.")
    try:
        return detect_toxic(texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"λͺ¨λΈ μΆλ‘ μ€ μ€λ₯ λ°μ: {e}")
| |
|