Initial code commit
Browse files- Dockerfile +12 -0
- api/endpoints.py +11 -0
- api/schemas.py +9 -0
- app.py +29 -0
- config.py +25 -0
- db/initializer.py +32 -0
- modules/corpus.py +61 -0
- modules/embedder.py +17 -0
- modules/reranker.py +39 -0
- modules/retriever.py +20 -0
- modules/utils.py +12 -0
- requirements.txt +7 -0
- service/search.py +11 -0
- templates/index.html +18 -0
Dockerfile
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FROM python:3.10-slim

WORKDIR /app

# Install dependencies first so the layer is cached across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# The code is copied to the image root (/app) with app.py at top level, so the
# ASGI import path is "app:app" — there is no top-level "rag" package here.
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT}"]
|
api/endpoints.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/api/endpoints.py
from fastapi import APIRouter

from api.schemas import SearchRequest, SearchResponse
# Use the bare top-level import style used by every other module in the
# project (api.schemas above, modules.* in service.search); the original
# "rag.service.search" prefix was inconsistent with them.
from service.search import search

router = APIRouter()


@router.post("/search", response_model=SearchResponse)
def search_context(req: SearchRequest):
    """Run the RAG pipeline for *req.query* and return the matched contexts.

    Returns a dict matching SearchResponse: {"results": [context dicts]}.
    """
    results = search(req.query)
    return {"results": results}
|
api/schemas.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/api/schemas.py
from pydantic import BaseModel
from typing import List, Dict, Any

class SearchRequest(BaseModel):
    """Request body for POST /search."""
    # Free-text user query to embed and search against the corpus.
    query: str

class SearchResponse(BaseModel):
    """Response body for POST /search."""
    # Reranked context records as produced by service.search.search().
    results: List[Dict[str, Any]]
|
app.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/app.py
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from api.endpoints import router
# Normalized to the bare import style already used above for api.endpoints;
# the original mixed "rag.db.initializer" / "rag.service.search" with bare
# imports in the same file, which cannot both resolve under one sys.path root.
from db.initializer import initialize
from service.search import search

templates = Jinja2Templates(directory="templates")

@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Startup hook: download/load the corpus and FAISS index once."""
    initialize()
    yield

app = FastAPI(lifespan=lifespan)
app.include_router(router)

@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the demo page with an empty result set."""
    return templates.TemplateResponse("index.html", {"request": request, "results": None})

@app.post("/demo", response_class=HTMLResponse)
def demo(request: Request, query: str = Form(...)):
    """Handle the HTML form submission and render the reranked results."""
    results = search(query)
    return templates.TemplateResponse("index.html", {"request": request, "results": results, "query": query})
|
config.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/config.py
"""Central configuration, read from environment variables with defaults."""
import os

# Huggingface Hub token
HF_TOKEN = os.getenv("HF_TOKEN")

# HF datasets repo info (prebuilt FAISS index file and vector-id file)
HF_REPO_ID = os.getenv("HF_REPO_ID", "m97j/pls-datasets")
HF_INDEX_FILE = os.getenv("HF_INDEX_FILE", "faiss_index_flat.faiss")
HF_IDS_FILE = os.getenv("HF_IDS_FILE", "vector_ids.npy")

# Corpus dataset info
HF_CORPUS_REPO = os.getenv("HF_CORPUS_REPO", "HuggingFaceFW/finewiki")
# Comma-separated subset list, e.g. "ko" or "ko,en" (split in modules.corpus)
HF_CORPUS_SUBSET = os.getenv("HF_CORPUS_SUBSET", "ko")
HF_CORPUS_SPLIT = os.getenv("HF_CORPUS_SPLIT", "train")

# Local paths
MARKER_DIR = os.getenv("MARKER_DIR", "rag/state")
# Marker file whose presence signals the corpus cache is already downloaded
CORPUS_READY_MARK = os.path.join(MARKER_DIR, ".corpus_ready")

# Embedding / LLM model
EMBED_MODEL = os.getenv("EMBED_MODEL", "intfloat/multilingual-e5-large")
# Number of nearest neighbours requested from the FAISS search
TOP_K = int(os.getenv("TOP_K", "5"))
|
db/initializer.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/db/initializer.py
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
from config import HF_REPO_ID, HF_INDEX_FILE, HF_IDS_FILE
from modules.utils import ensure_dir
from modules.retriever import set_index
from modules import corpus

# FAISS-row -> corpus id mapping loaded from HF_IDS_FILE (None until
# initialize()/force_update() has run).
_vector_ids = None

def _load_index_in_memory():
    """Download the index and id mapping from the HF Hub and load them in memory."""
    index_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_INDEX_FILE)
    ids_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_IDS_FILE)
    index = faiss.read_index(index_path)
    set_index(index)
    global _vector_ids
    _vector_ids = np.load(ids_path, allow_pickle=True)

def get_vector_ids():
    """Return the loaded vector-id mapping (None before initialization)."""
    # No `global` needed: this function only reads the module-level name.
    return _vector_ids

def initialize():
    """One-time startup work for the service."""
    # 1) Prepare the corpus (downloaded only on the first run)
    corpus.prepare_corpus()
    # 2) Load the index and id mapping into memory
    _load_index_in_memory()

def force_update():
    """Re-download and reload the index/id mapping (e.g. after a rebuild)."""
    _load_index_in_memory()
|
modules/corpus.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/modules/corpus_store.py
from typing import List, Dict, Any
from datasets import load_dataset, DatasetDict, Dataset
from config import HF_CORPUS_REPO, HF_CORPUS_SUBSET, HF_CORPUS_SPLIT, MARKER_DIR, CORPUS_READY_MARK
from modules.utils import ensure_dir, exists, touch

# Lazily-populated cache of loaded datasets, keyed by subset name.
_datasets: Dict[str, Dataset] = {}

def prepare_corpus():
    """
    Download the parquet splits to the local cache on the first run only.
    Subsequent runs detect the marker file and reuse the local cache.
    """
    ensure_dir(MARKER_DIR)
    if exists(CORPUS_READY_MARK):
        return

    subsets = HF_CORPUS_SUBSET.split(",")  # "ko,en" -> ["ko","en"]
    for subset in subsets:
        load_dataset(HF_CORPUS_REPO, subset.strip(), split=HF_CORPUS_SPLIT)
    touch(CORPUS_READY_MARK)

def _get_datasets() -> Dict[str, Dataset]:
    """Load (once) and return every configured subset, keyed by subset name."""
    global _datasets
    if not _datasets:
        subsets = HF_CORPUS_SUBSET.split(",")
        for subset in subsets:
            _datasets[subset.strip()] = load_dataset(
                HF_CORPUS_REPO, subset.strip(), split=HF_CORPUS_SPLIT
            )
    return _datasets

def fetch_contexts_by_ids(ids: List[int]) -> List[Dict[str, Any]]:
    """Return the corpus rows whose page_id is in *ids*, preserving *ids* order."""
    if not ids:
        return []

    datasets = _get_datasets()
    id_set = set(ids)
    results: List[Dict[str, Any]] = []

    # Walk every subset and match rows on page_id
    for subset, ds in datasets.items():
        # filter() is faster than a full Python-side scan (parallel-optimized)
        rows = ds.filter(lambda r: r["page_id"] in id_set)

        id_to_row = {r["page_id"]: r for r in rows}
        for i in ids:
            r = id_to_row.get(i)
            if r:
                results.append({
                    "id": r["page_id"],
                    "title": r.get("title", ""),
                    "text": r.get("wikitext", ""),
                    "url": r.get("url", ""),
                    "metadata": {
                        "date_modified": r.get("date_modified", ""),
                        "in_language": r.get("in_language", ""),
                        "wikidata_id": r.get("wikidata_id", "")
                    }
                })
    return results
|
modules/embedder.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/modules/embedder.py
import math
from typing import List
from huggingface_hub import InferenceClient
from config import EMBED_MODEL, HF_TOKEN

_client = InferenceClient(model=EMBED_MODEL, token=HF_TOKEN)

def _l2_normalize(vec: List[float]) -> List[float]:
    """Scale *vec* to unit L2 norm (an all-zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def get_embedding(text: str) -> List[float]:
    """Embed a single *text* via the HF Inference API and L2-normalize it.

    feature_extraction typically returns a 2-D [batch, dim] array for a single
    input, but some backends return a flat 1-D vector; the original code
    assumed 2-D and would break on the 1-D shape, so handle both here.
    """
    embedding = _client.feature_extraction(text)
    first = embedding[0]
    # 2-D: first row is the input sentence's vector; 1-D: the array itself is.
    vec = first if hasattr(first, "__len__") else embedding
    return _l2_normalize([float(x) for x in vec])
|
modules/reranker.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/modules/reranker.py
import os
from typing import List, Dict
from huggingface_hub import InferenceClient

# Model name and token are read from environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")

_client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)

# The threshold can also be managed via environment variable or config
THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))

def rerank(query: str, contexts: List[Dict]) -> List[Dict]:
    """
    contexts: [{"id": ..., "text": ...}, ...]
    Returns: reranked contexts containing only entries scoring >= THRESHOLD.
    Mutates each context dict in place by adding a "score" key.
    """
    if not contexts:
        return []

    # Reranker input: list of (query, passage) pairs
    pairs = [(query, ctx["text"]) for ctx in contexts]

    # Inference API call -> one score per pair
    # NOTE(review): huggingface_hub 0.24.6 (pinned in requirements.txt) does
    # not expose an InferenceClient.rerank method with this signature — verify
    # against the installed huggingface_hub version before relying on this.
    scores = _client.rerank(inputs=pairs)

    # scores is expected to be of the form [{"score": float}, ...]
    for ctx, sc in zip(contexts, scores):
        ctx["score"] = sc["score"]

    # Sort by score, descending
    reranked = sorted(contexts, key=lambda x: x["score"], reverse=True)

    # Keep only entries at or above the threshold
    reranked = [c for c in reranked if c["score"] >= THRESHOLD]

    return reranked
|
modules/retriever.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/modules/retriever.py
import numpy as np
from config import TOP_K

_index = None  # in-memory FAISS index

def set_index(index_obj):
    """Install the FAISS index that retrieve_ids() will search."""
    global _index
    _index = index_obj

def has_index() -> bool:
    """Report whether a FAISS index has been installed."""
    return _index is not None

def retrieve_ids(query_embedding: list[float]) -> list[int]:
    """Search the in-memory index and return the TOP_K nearest neighbour ids.

    Raises RuntimeError when no index has been loaded via set_index().
    """
    if _index is None:
        raise RuntimeError("FAISS index is not loaded in memory.")
    query = np.asarray([query_embedding], dtype="float32")
    _, neighbors = _index.search(query, TOP_K)
    return [int(row_id) for row_id in neighbors[0]]
|
modules/utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/modules/utils.py
import os

def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)

def touch(path: str) -> None:
    """Create an empty file at *path*; an existing file keeps its contents."""
    # Append mode creates the file when missing without truncating it.
    open(path, "a", encoding="utf-8").close()

def exists(path: str) -> bool:
    """Report whether *path* exists on disk."""
    return os.path.exists(path)
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fastapi==0.115.0
uvicorn[standard]==0.30.0

huggingface_hub==0.24.6
faiss-cpu==1.8.0
numpy==1.26.4
pydantic==2.11.5
datasets
jinja2
python-multipart
|
service/search.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# rag/service/search.py
from modules.embedder import get_embedding
from modules.retriever import retrieve_ids
from modules.corpus import fetch_contexts_by_ids
from modules.reranker import rerank
from db.initializer import get_vector_ids

def search(query: str) -> list[dict]:
    """Full RAG pipeline: embed -> FAISS search -> id mapping -> fetch -> rerank."""
    embedding = get_embedding(query)
    row_ids = retrieve_ids(embedding)
    # FAISS returns row positions in the index, while fetch_contexts_by_ids
    # matches on corpus page_id. The vector_ids.npy mapping loaded by
    # db.initializer translates row position -> page_id; without this step the
    # raw positions were passed straight through and almost never matched.
    # The bounds check also drops FAISS's -1 "no neighbour" sentinel.
    mapping = get_vector_ids()
    if mapping is not None:
        ids = [int(mapping[i]) for i in row_ids if 0 <= i < len(mapping)]
    else:
        # Mapping not loaded (e.g. initialize() not run) — best-effort fallback.
        ids = row_ids
    contexts = fetch_contexts_by_ids(ids)
    reranked = rerank(query, contexts)
    return reranked
|
templates/index.html
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<!DOCTYPE html>
<html lang="en">
<!-- Demo page rendered by app.index (GET /) and app.demo (POST /demo). -->
<head>
  <!-- Declare UTF-8 explicitly so non-ASCII queries/results render correctly
       even when no charset arrives in the Content-Type header. -->
  <meta charset="utf-8">
  <title>RAG Search Demo</title>
</head>
<body>
<h1>RAG Search Demo</h1>
<form method="post" action="/demo">
  <input type="text" name="query" placeholder="Enter query" required>
  <button type="submit">Search</button>
</form>
{% if results %}
<h2>Results for: {{ query }}</h2>
<ul>
  {% for r in results %}
  <li>{{ r }}</li>
  {% endfor %}
</ul>
{% endif %}
</body>
</html>
|