test / backend /main.py
uuuy5615's picture
Update backend/main.py
d1617b6 verified
from fastapi import FastAPI, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Optional
from backend.recommendWord import recommendWord
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
from keybert import KeyBERT
from kiwipiepy import Kiwi
import pandas as pd
import faiss
from transformers import AutoTokenizer, AutoModelForMaskedLM
from backend.ref import refRecommend
from backend.spellchecker import check
from backend.auth.routes import router as auth_router
from backend.auth.routes import get_current_user
from backend.post_router import router as post_router
import os
import requests
# Remote locations of the FAISS index file and its companion CSV; both live
# under the same Hugging Face dataset path.
_DATASET_BASE_URL = "https://huggingface.co/datasets/uuuy5615/my_index/resolve/main"
FAISS_URL = f"{_DATASET_BASE_URL}/faiss_index.idx"
CSV_URL = f"{_DATASET_BASE_URL}/kci.csv"

# Local file names the two downloads are stored under (relative paths, so
# they resolve against the process's working directory).
FAISS_PATH = "faiss_index.idx"
CSV_PATH = "kci.csv"
def mask_by_position(sentence: str, start: int, end: int) -> str:
    """Return *sentence* with the slice ``[start:end]`` replaced by ``[MASK]``.

    ``start`` and ``end`` are used directly as Python slice bounds, so they
    are 0-based, ``end`` is exclusive, and out-of-range or negative values
    follow ordinary slicing rules rather than raising.
    """
    prefix = sentence[:start]
    suffix = sentence[end:]
    return f"{prefix}[MASK]{suffix}"
def _download_if_missing(url: str, path: str, label: str) -> None:
    """Fetch *url* into *path* unless *path* already exists.

    The response body is streamed into a temporary sibling file and moved
    into place only after the whole transfer succeeded, so an interrupted
    download cannot leave a truncated file that a later start-up would
    mistake for a complete one.

    Args:
        url: Address to download from.
        path: Destination file on local disk.
        label: Short name inserted into the progress messages.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(path):
        return

    # NOTE(review): the message text below looks mis-encoded in this file. It
    # is runtime output, so it is reproduced unchanged - confirm the intended
    # wording before touching it.
    print(f"{label} 파일 λ‹€μš΄λ‘œλ“œ 쀑...")
    partial_path = f"{path}.part"
    try:
        # stream=True avoids holding the whole payload in memory; the
        # (connect, read) timeout keeps start-up from hanging forever on a
        # stalled connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()  # raise on HTTP error statuses
            with open(partial_path, "wb") as out:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    out.write(chunk)
        os.replace(partial_path, path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # otherwise drop the leftover so the next start retries cleanly.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"{label} λ‹€μš΄λ‘œλ“œ μ™„λ£Œ!")


# Download the FAISS index file, then the CSV file, if not present locally.
_download_if_missing(FAISS_URL, FAISS_PATH, "FAISS")
_download_if_missing(CSV_URL, CSV_PATH, "CSV")
# --- Resources for reference recommendation (RefRec) ---
# One sentence-embedding model instance is shared with KeyBERT.
refModel = SentenceTransformer("jhgan/ko-sbert-nli")
kw_model = KeyBERT(refModel)
kiwi = Kiwi()
# Read from the same path constants the download step writes to, so that
# changing a constant can never make the loader look at a different file.
df = pd.read_csv(CSV_PATH, low_memory=False)
index = faiss.read_index(FAISS_PATH)

# --- Resources for word recommendation (WordRec) ---
# Tokenizer and masked-LM are loaded from one checkpoint name so they cannot
# drift apart.
_WORD_MODEL_NAME = "klue/roberta-large"
tokenizer = AutoTokenizer.from_pretrained(_WORD_MODEL_NAME)
wordModel = AutoModelForMaskedLM.from_pretrained(_WORD_MODEL_NAME)
# Application instance with the auth and post routers mounted under their
# own URL prefixes.
app = FastAPI()
app.include_router(auth_router, prefix="/auth", tags=["auth"])
app.include_router(post_router, prefix="/post", tags=["post"])

# Origins allowed to make cross-origin requests (addresses of the React app).
_ALLOWED_ORIGINS = [
    "http://127.0.0.1:5173",
    "https://geulditbul.vercel.app",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body of POST /api/spellcheck.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as it was.)
class SpellCheckRequest(BaseModel):
    # Text that is handed to the spell checker.
    text: str
# One entry of SpellCheckResponse.corrections.
# Field meanings are not visible in this file; presumably they mirror what
# backend.spellchecker.check emits - confirm there.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as it was.)
class Correction(BaseModel):
    error: str
    checked: str
    # NOTE(review): no default is given. Pydantic v1 treats a bare
    # Optional[...] field as optional (default None) while Pydantic v2 treats
    # it as required-but-nullable - confirm which behaviour is intended.
    position: Optional[int]
    length: int
    errortype: int
# Response schema of POST /api/spellcheck (used as that route's
# response_model).
# Field meanings are not visible in this file; presumably they mirror the
# dict produced by backend.spellchecker.check - confirm there.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as it was.)
class SpellCheckResponse(BaseModel):
    flag: int
    original_text: str
    checked_text: str
    # Individual corrections, each validated as a Correction.
    corrections: List[Correction]
    time: float
@app.post("/api/spellcheck", response_model=SpellCheckResponse)
def api_spellcheck(req: SpellCheckRequest, _: dict = Depends(get_current_user)):
    # Run the spell checker on the submitted text.
    #
    # The `_` parameter exists only so that the get_current_user dependency
    # runs for this route; its value is not used here.
    # (Comments instead of a docstring keep the generated OpenAPI description
    # of this endpoint unchanged.)
    #
    # According to the original author's note, check() returns a dict; it is
    # returned as-is and shaped by response_model=SpellCheckResponse.
    result = check(req.text)
    return result
@app.get("/model/WordRec")
async def runWordRec(
    user_sentence: str,
    MaskWord: str,
    start: int,
    end: int,
    _: dict = Depends(get_current_user),
):
    # Recommend replacement words for a span of `user_sentence`.
    #
    # Query parameters:
    #   user_sentence: sentence containing the word to replace.
    #   MaskWord:      forwarded to recommendWord and echoed in the response.
    #   start, end:    span to mask. 1 is subtracted from each before slicing,
    #                  so 1-based positions are assumed - TODO confirm with
    #                  the client.
    #   _:             only present so the get_current_user dependency runs.
    #
    # Returns a JSONResponse with keys "Model", "masked_word", "rec_result".
    # Responds with HTTP 422 when the span is invalid.
    #
    # NOTE(review): recommendWord is called synchronously inside an async
    # handler; if it is slow it blocks the event loop while it runs.
    from fastapi import HTTPException  # local import: not in the file header

    # Without this check start <= 0 becomes a negative slice index (wrapping
    # around to the end of the sentence) and end < start duplicates text, so
    # the masked sentence would be silently wrong.
    if start < 1 or end < start:
        raise HTTPException(
            status_code=422,
            detail="start must be >= 1 and end must be >= start",
        )

    sentence = mask_by_position(user_sentence, start - 1, end - 1)
    rec_words = recommendWord(sentence, MaskWord, tokenizer, wordModel)
    data = {"Model": "WordRec", "masked_word": MaskWord, "rec_result": rec_words}
    return JSONResponse(content=data)
@app.get("/model/RefRec")
async def runRefRec(text: str, _: dict = Depends(get_current_user)):
    # Recommend references for `text`.
    #
    # refRecommend receives the module-level embedding model, KeyBERT
    # instance, Kiwi instance, DataFrame and FAISS index, and its two return
    # values are sent back under "name_result" and "link_result".
    # The `_` parameter only exists so the get_current_user dependency runs.
    # (Comments instead of a docstring keep the generated OpenAPI description
    # of this endpoint unchanged.)
    #
    # NOTE(review): refRecommend is called synchronously inside an async
    # handler; if it is slow it blocks the event loop while it runs.
    ref_names, ref_links = refRecommend(refModel, kw_model, kiwi, text, df, index)
    payload = {
        "Model": "RefRec",
        "name_result": ref_names,
        "link_result": ref_links,
    }
    return JSONResponse(content=payload)