from fastapi import FastAPI, Depends from pydantic import BaseModel, Field from typing import Dict, List, Optional import models from database import engine, SessionLocal from sqlalchemy.orm import Session from models import Article, AnalysisResult from crossref_model import get_crossref_score_and_reason from mismatch_model import calculate_mismatch_score from aggro_model import get_aggro_score from fastapi.middleware.cors import CORSMiddleware models.Base.metadata.create_all(bind=engine) from database import SessionLocal # database.py에서 정의한 SessionLocal 가져오기 # DB 세션을 열고 닫는 Dependency 함수 정의 def get_db(): db = SessionLocal() try: yield db finally: db.close() # --- 1. API 명세서 (Request/Response) --- # Pydantic 모델, API 명세 초안 class ArticleRequest(BaseModel): """프론트가 보낼 요청 형식""" article_title: str = Field(..., example="기사 제목이 여기 들어갑니다") article_body: str = Field(..., example="기사 본문 텍스트...") class FoundURL(BaseModel): """SBERT가 검증한 URL 객체""" url: str similarity: float = Field(..., example=0.85) class ScoreBreakdown(BaseModel): """개별 점수 상세내역 형식""" score: float = Field(..., example=0.95) reason: str = Field(..., example="'충격' 키워드 사용") recommendation: str = Field(..., example="중립적인 단어로 수정하세요.") found_urls: Optional[List[FoundURL]] = None class AnalysisResponse(BaseModel): """백엔드->프론트 최종 응답 형식""" final_risk_score: float = Field(..., example=0.82) final_risk_level: str = Field(..., example="위험") breakdown: Dict[str, ScoreBreakdown] # --- 2. FastAPI 앱 생성 --- app = FastAPI() origins = [ "http://localhost:3000", # 프론트엔드 개발자가 요청한 주소 (로컬 테스트용) "https://my-frontend-app.vercel.app", # (나중에) 프론트가 배포될 실제 주소도 미리 넣으면 좋음 "*" # 귀찮으면 이걸 주석 해제해서 다 허용해도 됩니다. ] app.add_middleware( CORSMiddleware, allow_origins=origins, # 위에서 만든 리스트를 적용 allow_credentials=True, allow_methods=["*"], # 모든 HTTP 메서드 허용 (GET, POST 등) allow_headers=["*"], # 모든 헤더 허용 ) # --- 3. API 엔드포인트 --- @app.post("/api/v1/analyze", response_model=AnalysisResponse) def analyze_article(request: ArticleRequest,db: Session = Depends(get_db)): """기사 분석 API""" # 1. AggroScore aggro_result = get_aggro_score(request.article_title) real_aggro = ScoreBreakdown( score=aggro_result["score"], reason=aggro_result["reason"], recommendation=aggro_result["recommendation"], found_urls=None ) # 2. mismatch mismatch_result = calculate_mismatch_score(request.article_title, request.article_body) real_mismatch = ScoreBreakdown( score=mismatch_result["score"], reason=mismatch_result["reason"], recommendation=mismatch_result["recommendation"], found_urls=None ) # 3. crossref real_crossref_data = get_crossref_score_and_reason(request.article_body) SIMILARITY_THRESHOLD = 0.7 # 70% 이상 일치하는 것만 보여주기 # 유사도가 높은 순으로 정렬 sorted_urls = sorted( real_crossref_data["paired_results"], key=lambda x: x["similarity"], reverse=True ) # 임계값(THRESHOLD) 이상의 URL만 필터링 filtered_urls = [ FoundURL(url=item["url"], similarity=item["similarity"]) for item in sorted_urls if item["similarity"] >= SIMILARITY_THRESHOLD ] # 최종 CrossrefScore 객체 생성 (필터링된 URL 포함) final_crossref = ScoreBreakdown( score=real_crossref_data["score"], reason=real_crossref_data["reason"], recommendation=real_crossref_data["recommendation"], found_urls=filtered_urls ) # 최종 위험도 계산 w_aggro = 0.2 w_mismatch = 0.5 w_crossref = 0.3 final_score = (real_aggro.score * w_aggro) + \ (real_mismatch.score * w_mismatch) + \ (final_crossref.score * w_crossref) final_level = "안전" if final_score >= 0.5: final_level = "위험" # ------------------------------------------------ # 🛑 [핵심 추가] DB 저장 로직 시작 # ------------------------------------------------ # 1. 'articles' 테이블에 기사 저장 new_article = Article( title=request.article_title, body=request.article_body, source="Swagger UI Test" # 테스트용 출처 입력 ) db.add(new_article) # 2. article_id를 얻기 위해 Flush (아직 commit은 하지 않음) db.flush() # 3. 'analysis_results' 테이블에 분석 결과 저장 new_result = AnalysisResult( article_id=new_article.article_id, # 외래 키 연결 aggro_score=real_aggro.score, mismatch_score=real_mismatch.score, crossref_score=final_crossref.score, final_risk=final_score, ) db.add(new_result) # 4. 모든 변경 사항을 DB에 영구 저장 (트랜잭션 완료) db.commit() # ------------------------------------------------ # API 명세서(AnalysisResponse) 형식에 맞춰서 반환 return AnalysisResponse( final_risk_score=round(final_score, 4), # 소수점 4자리로 반올림 final_risk_level=final_level, breakdown={ "aggro_score": real_aggro, "mismatch_score": real_mismatch, "crossref_score": final_crossref } ) @app.get("/") def read_root(): return {"message": "AI 기사 분석 서버"}