# server.py import os from typing import List from fastapi import FastAPI, HTTPException from pydantic import BaseModel import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification # ── 1) FastAPI 앱 생성 app = FastAPI( title="AGaRiCleaner Toxicity Detector (FastAPI)", description="FastAPI 기반 한국어 악플 탐지 모델 서버", version="1.0.0" ) # ── 2) 요청 스키마 정의 (Pydantic 모델) class TextsIn(BaseModel): data: List[str] # JSON 예시: { "data": ["문장1", "문장2", ...] } # ── 3) 모델 디렉터리 경로 (Space에서는 /app/detector 폴더가 된다) MODEL_DIR = "./detector" # ─── 여기부터 디버그 코드 ─────────────────────────────────────────────────────── print("===== DETECTOR 디렉터리 내부 상태 =====") for fname in sorted(os.listdir(MODEL_DIR)): fpath = os.path.join(MODEL_DIR, fname) try: size = os.path.getsize(fpath) except Exception: size = None preview = "" # JSON 파일인 경우, 첫 줄만 읽어서 preview에 담아둡니다. if fname.endswith(".json"): try: with open(fpath, "r", encoding="utf-8") as f: preview = f.readline().strip() except Exception: preview = "<읽기 실패>" print(f" • {fname:<25} size={size:<8} preview={preview[:60]!r}") print("========================================\n") # ─── 디버그 코드 끝 ───────────────────────────────────────────────────────────── # (이제부터 실제 토크나이저 로드) # ── 4) 디바이스 설정 (Mac MPS 지원 여부 확인) device = "mps" if torch.backends.mps.is_available() else "cpu" print(f"▶ 모델 추론 디바이스: {device}") # ── 5) 토크나이저와 모델 로드 try: tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR) model.to(device) model.eval() print("✔ 모델 및 토크나이저 로드 완료") except Exception as e: print("✖ 모델 로드 실패:", e) raise e # ── 6) 악플 탐지 함수 정의 def detect_toxic(texts: List[str]) -> List[dict]: encoding = tokenizer( texts, padding=True, truncation=True, return_tensors="pt", max_length=128 ) input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) with torch.no_grad(): outputs = model(input_ids=input_ids, attention_mask=attention_mask) logits = outputs.logits probs = torch.softmax(logits, dim=-1).cpu().tolist() results = [] for i, text in enumerate(texts): score_1 = probs[i][1] label = 1 if score_1 >= 0.5 else 0 results.append({ "text": text, "label": label, "score": round(score_1, 6) }) return results # ── 7) POST /predict 엔드포인트 정의 @app.post("/predict", summary="텍스트 목록을 입력받아 악플 여부(label, score) 반환") async def predict_endpoint(payload: TextsIn): texts = payload.data if not isinstance(texts, list) or len(texts) == 0: raise HTTPException(status_code=400, detail="‘data’ 필드에 최소 1개 이상의 문자열이 있어야 합니다.") try: output = detect_toxic(texts) return output except Exception as e: raise HTTPException(status_code=500, detail=f"모델 추론 중 오류 발생: {e}")