| |
|
|
| from typing import List |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
# FastAPI application object; title/description feed the auto-generated OpenAPI docs.
app = FastAPI(
    title="AGaRiCleaner Toxicity Detector (FastAPI)",
    # "FastAPI-based Korean toxicity detection model server"
    # (reconstructed from mojibake; the original literal was split across lines)
    description="FastAPI 기반 한국어 유해 탐지 모델 서버",
    version="1.0.0",
)
|
|
| |
class TextsIn(BaseModel):
    """Request body for POST /predict: a batch of raw sentences to classify."""

    # Input texts to score; the endpoint rejects an empty list with HTTP 400.
    data: List[str]
|
|
| |
# Directory holding the fine-tuned classifier checkpoint (save_pretrained layout
# expected by the transformers loaders below) — presumably produced by a
# separate training script; confirm the path against the deployment layout.
MODEL_DIR = "./detector"

# Prefer Apple-Silicon MPS acceleration when available, otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"▶ 모델 추론 디바이스: {device}")  # "model inference device"
|
|
| |
# Load tokenizer and classification head once at import time; the process must
# not start serving if the model cannot be loaded, so failures are fatal.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.to(device)
    model.eval()  # inference mode: disables dropout for deterministic outputs
    print("✅ 모델 및 토크나이저 로드 완료")  # "model and tokenizer loaded"
except Exception as e:
    print("❌ 모델 로드 실패:", e)  # "model load failed"
    # Bare `raise` re-raises with the original traceback intact
    # (`raise e` would rewrite the traceback from this frame).
    raise
|
|
| |
def detect_toxic(texts: List[str], threshold: float = 0.5) -> List[dict]:
    """Classify each text as toxic (label 1) or clean (label 0).

    Args:
        texts: Batch of raw input strings.
        threshold: Toxic-class probability at or above which a text is
            labelled 1. Defaults to 0.5 (the original hard-coded cutoff).

    Returns:
        One dict per input, in input order:
        ``{"text": str, "label": 0 | 1, "score": float}`` where ``score``
        is the toxic-class probability rounded to 6 decimal places.
    """
    # Empty batch: nothing to score, and the tokenizer would raise on [].
    if not texts:
        return []

    encoding = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=128,  # presumably matches the training-time sequence length — confirm
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Assumes a binary classification head: column 1 is the toxic class.
        probs = torch.softmax(outputs.logits, dim=-1).cpu().tolist()

    return [
        {
            "text": text,
            "label": 1 if prob[1] >= threshold else 0,
            "score": round(prob[1], 6),
        }
        for text, prob in zip(texts, probs)
    ]
|
|
| |
# Summary: "takes a list of texts, returns toxicity verdict (label, score)"
# (reconstructed from a mojibake literal that was split across several lines)
@app.post("/predict", summary="텍스트 목록을 입력받아 유해 여부(label, score) 반환")
async def predict_endpoint(payload: TextsIn):
    """Score a batch of texts; returns per-text label (0/1) and toxic probability."""
    texts = payload.data
    # Pydantic validation already guarantees a list of strings;
    # only emptiness still needs an explicit check.
    if not texts:
        # "'data' field must contain at least one string"
        raise HTTPException(status_code=400, detail="'data' 필드에 최소 1개 이상의 문자열이 있어야 합니다.")
    try:
        return detect_toxic(texts)
    except Exception as e:
        # Surface inference failures as a clean 500 instead of an unhandled crash;
        # `from e` keeps the causal traceback for server-side logs.
        # "error occurred during model inference"
        raise HTTPException(status_code=500, detail=f"모델 추론 중 오류 발생: {e}") from e
|
|