"""FastAPI entry point for the spell-check, word-recommendation and
reference-recommendation endpoints.

On import this module:

* downloads the FAISS index and the KCI CSV from Hugging Face if they are not
  already present locally,
* loads the sentence-embedding, keyword, morphological-analyzer and
  masked-LM models used by the endpoints,
* builds the FastAPI ``app`` with the auth and post routers mounted and CORS
  configured.
"""

from fastapi import FastAPI, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Optional
from backend.recommendWord import recommendWord
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
from keybert import KeyBERT
from kiwipiepy import Kiwi
import pandas as pd
import faiss
from transformers import AutoTokenizer, AutoModelForMaskedLM
from backend.ref import refRecommend
from backend.spellchecker import check
from backend.auth.routes import router as auth_router
from backend.auth.routes import get_current_user
from backend.post_router import router as post_router
import os
import requests

FAISS_URL = (
    "https://huggingface.co/datasets/uuuy5615/my_index/resolve/main/faiss_index.idx"
)
CSV_URL = "https://huggingface.co/datasets/uuuy5615/my_index/resolve/main/kci.csv"
FAISS_PATH = "faiss_index.idx"
CSV_PATH = "kci.csv"

# (connect, read) timeouts in seconds for the startup downloads. Without a
# timeout, requests can block forever on a stalled connection. The read
# timeout applies between received bytes, not to the whole transfer.
DOWNLOAD_TIMEOUT = (10, 120)
# Stream downloads in 1 MiB chunks so large files are never held in memory whole.
_DOWNLOAD_CHUNK_SIZE = 1 << 20


def mask_by_position(sentence: str, start: int, end: int) -> str:
    """Return *sentence* with ``sentence[start:end]`` replaced by ``[MASK]``.

    ``start`` and ``end`` are ordinary Python slice indices (0-based, with
    ``end`` exclusive).
    """
    return sentence[:start] + "[MASK]" + sentence[end:]


def _download_if_missing(url: str, path: str, label: str) -> None:
    """Download *url* to *path* unless *path* already exists.

    The response body is streamed to ``<path>.part`` and renamed onto *path*
    only once it is complete. An interrupted download therefore never leaves
    behind a truncated file that the existence check would later treat as
    finished.

    Args:
        url: Source URL.
        path: Destination file path.
        label: Name shown in the progress messages (e.g. ``"FAISS"``).

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    if os.path.exists(path):
        return
    print(f"{label} 파일 다운로드 중...")
    tmp_path = f"{path}.part"
    try:
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as resp:
            resp.raise_for_status()  # fail loudly on HTTP errors
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                    f.write(chunk)
        # Atomic rename: *path* only ever appears fully written.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the .part file is gone, so this runs
        # only on failure. It drops the partial file; the error still propagates.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"{label} 다운로드 완료!")


_download_if_missing(FAISS_URL, FAISS_PATH, "FAISS")
_download_if_missing(CSV_URL, CSV_PATH, "CSV")

# --- Reference recommendation (RefRec) resources ---
refModel = SentenceTransformer("jhgan/ko-sbert-nli")
kw_model = KeyBERT(refModel)
kiwi = Kiwi()
# Read from the same paths the downloads above write to.
df = pd.read_csv(CSV_PATH, low_memory=False)
index = faiss.read_index(FAISS_PATH)

# --- Word recommendation (WordRec) resources ---
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
wordModel = AutoModelForMaskedLM.from_pretrained("klue/roberta-large")

app = FastAPI()
app.include_router(auth_router, prefix="/auth", tags=["auth"])
app.include_router(post_router, prefix="/post", tags=["post"])
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://127.0.0.1:5173",
        "https://geulditbul.vercel.app",
    ],  # front-end (React app) origins allowed to call this API
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SpellCheckRequest(BaseModel):
    """Request body for ``POST /api/spellcheck``."""

    text: str  # text to be spell-checked


class Correction(BaseModel):
    """One correction entry inside :class:`SpellCheckResponse`.

    Field values come from ``backend.spellchecker.check``. Their exact
    meaning is defined there.
    """

    error: str
    checked: str
    position: Optional[int]
    length: int
    errortype: int


class SpellCheckResponse(BaseModel):
    """Response schema for ``POST /api/spellcheck``.

    Populated from the dict returned by ``backend.spellchecker.check``.
    """

    flag: int
    original_text: str
    checked_text: str
    corrections: List[Correction]
    time: float


@app.post("/api/spellcheck", response_model=SpellCheckResponse)
def api_spellcheck(req: SpellCheckRequest, _: dict = Depends(get_current_user)):
    """Spell-check ``req.text``.

    Requires the ``get_current_user`` dependency to resolve. ``check`` returns
    a dict, which FastAPI validates and serializes through
    ``SpellCheckResponse``.
    """
    return check(req.text)


@app.get("/model/WordRec")
async def runWordRec(
    user_sentence: str,
    MaskWord: str,
    start: int,
    end: int,
    _: dict = Depends(get_current_user),
):
    """Recommend replacement words for a span of *user_sentence*.

    The span is replaced with ``[MASK]`` via
    ``mask_by_position(user_sentence, start - 1, end - 1)``. The masked sentence
    and *MaskWord* are then passed to ``recommendWord``.

    NOTE(review): the ``- 1`` shift suggests the client sends 1-based offsets.
    Confirm the exact start/end convention with the front end.
    NOTE(review): model inference runs synchronously inside an ``async``
    endpoint and so blocks the event loop while it runs. Consider a plain
    ``def`` endpoint or ``run_in_threadpool`` once thread-safety of the
    tokenizer/model has been confirmed.
    """
    sentence = mask_by_position(user_sentence, start - 1, end - 1)
    rec_words = recommendWord(sentence, MaskWord, tokenizer, wordModel)
    data = {"Model": "WordRec", "masked_word": MaskWord, "rec_result": rec_words}
    return JSONResponse(content=data)


@app.get("/model/RefRec")
async def runRefRec(text: str, _: dict = Depends(get_current_user)):
    """Return the ``(name, link)`` results of ``refRecommend`` for *text*.

    Uses the module-level embedding model, KeyBERT model, Kiwi analyzer, KCI
    DataFrame and FAISS index loaded at import time.

    NOTE(review): like ``runWordRec``, this runs blocking work inside an
    ``async`` endpoint.
    """
    name, link = refRecommend(refModel, kw_model, kiwi, text, df, index)
    data = {"Model": "RefRec", "name_result": name, "link_result": link}
    return JSONResponse(content=data)