m97j committed on
Commit
4fdc679
ยท
1 Parent(s): dcb92f1

Initial codes commit

Browse files
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install dependencies first so Docker layer caching skips the re-install
# when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# The repo root is copied straight into /app, so the FastAPI app module is
# top-level "app" (app.py), not a "rag" package — "rag.app:app" would fail
# with ModuleNotFoundError at container start.
# ${PORT:-7860} falls back to 7860 (Hugging Face Spaces default) when PORT
# is unset, instead of handing uvicorn an empty --port value.
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}"]
api/endpoints.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/api/endpoints.py
"""Search API endpoint exposing the RAG retrieval pipeline as JSON."""
from fastapi import APIRouter
from api.schemas import SearchRequest, SearchResponse
# Root-relative import: every other module in this repo (config, modules.*,
# api.*) is imported relative to the repo root, so the original
# "from rag.service.search import search" could not resolve at runtime.
from service.search import search

router = APIRouter()

@router.post("/search", response_model=SearchResponse)
def search_context(req: SearchRequest):
    """Run the embed -> retrieve -> rerank pipeline for ``req.query``."""
    results = search(req.query)
    return {"results": results}
api/schemas.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# rag/api/schemas.py
"""Pydantic request/response models for the POST /search endpoint."""
from pydantic import BaseModel
from typing import List, Dict, Any

class SearchRequest(BaseModel):
    """Request body of POST /search."""
    # Free-text user query to embed and search for.
    query: str

class SearchResponse(BaseModel):
    """Response body of POST /search."""
    # Reranked context documents as produced by service.search.search;
    # each dict carries id/title/text/url/metadata (and a rerank score).
    results: List[Dict[str, Any]]
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/app.py
"""FastAPI entry point: JSON search API plus a minimal HTML demo page."""
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from api.endpoints import router
# Root-relative imports: this file already imports "api.endpoints" relative
# to the repo root, so the "rag." prefix the original used for these two
# modules was inconsistent and could not resolve at runtime.
from db.initializer import initialize
from service.search import search

templates = Jinja2Templates(directory="templates")

@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Download corpus/index and load them into memory once at startup."""
    initialize()
    yield

app = FastAPI(lifespan=lifespan)
app.include_router(router)

@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the empty demo search form."""
    return templates.TemplateResponse("index.html", {"request": request, "results": None})

@app.post("/demo", response_class=HTMLResponse)
def demo(request: Request, query: str = Form(...)):
    """Run a search from the HTML form and render the results inline."""
    results = search(query)
    return templates.TemplateResponse("index.html", {"request": request, "results": results, "query": query})
config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/config.py
"""Central configuration; every value is overridable via environment variables."""
import os

# Hugging Face Hub token (needed for private repos / Inference API calls).
HF_TOKEN = os.getenv("HF_TOKEN")

# HF datasets repo holding the prebuilt FAISS index and its ID mapping file.
HF_REPO_ID = os.getenv("HF_REPO_ID", "m97j/pls-datasets")
HF_INDEX_FILE = os.getenv("HF_INDEX_FILE", "faiss_index_flat.faiss")
HF_IDS_FILE = os.getenv("HF_IDS_FILE", "vector_ids.npy")

# Corpus dataset info. HF_CORPUS_SUBSET may be a comma-separated list
# (e.g. "ko,en") — it is split on "," by modules/corpus.py.
HF_CORPUS_REPO = os.getenv("HF_CORPUS_REPO", "HuggingFaceFW/finewiki")
HF_CORPUS_SUBSET = os.getenv("HF_CORPUS_SUBSET", "ko")
HF_CORPUS_SPLIT = os.getenv("HF_CORPUS_SPLIT", "train")

# Local paths: MARKER_DIR holds the ".corpus_ready" marker file recording
# that the one-time corpus download already completed.
MARKER_DIR = os.getenv("MARKER_DIR", "rag/state")
CORPUS_READY_MARK = os.path.join(MARKER_DIR, ".corpus_ready")

# Embedding model (served via the HF Inference API) and retrieval depth.
EMBED_MODEL = os.getenv("EMBED_MODEL", "intfloat/multilingual-e5-large")
TOP_K = int(os.getenv("TOP_K", "5"))
db/initializer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/db/initializer.py
"""Startup loading of the FAISS index and vector-ID mapping from HF Hub."""
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
from config import HF_REPO_ID, HF_INDEX_FILE, HF_IDS_FILE
from modules.utils import ensure_dir
from modules.retriever import set_index
from modules import corpus

# Maps FAISS row positions to corpus page ids; populated by initialize().
_vector_ids = None

def _load_index_in_memory():
    """Download the index / ID mapping from HF Hub and load them in memory."""
    global _vector_ids  # declared up front, before any use of the name
    index_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_INDEX_FILE)
    ids_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_IDS_FILE)
    index = faiss.read_index(index_path)
    set_index(index)
    # allow_pickle covers object-dtype arrays; the file comes from our own
    # dataset repo, so the pickle-loading risk is accepted here.
    _vector_ids = np.load(ids_path, allow_pickle=True)

def get_vector_ids():
    """Return the loaded vector-ID mapping (None before initialize())."""
    # Reading a module global needs no `global` declaration.
    return _vector_ids

def initialize():
    """One-shot startup: prepare the corpus, then load the index."""
    # 1) Prepare corpus (downloaded only on the very first run)
    corpus.prepare_corpus()
    # 2) Load index / ID mapping into memory
    _load_index_in_memory()

def force_update():
    """Re-download and reload the index, e.g. after the dataset repo changed."""
    _load_index_in_memory()
modules/corpus.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/modules/corpus.py
"""Corpus access: one-time dataset download plus page_id -> context lookup."""
from typing import List, Dict, Any
from datasets import load_dataset, DatasetDict, Dataset
from config import HF_CORPUS_REPO, HF_CORPUS_SUBSET, HF_CORPUS_SPLIT, MARKER_DIR, CORPUS_READY_MARK
from modules.utils import ensure_dir, exists, touch

# Process-wide cache of loaded subsets, keyed by subset name (e.g. "ko").
_datasets: Dict[str, Dataset] = {}

def prepare_corpus():
    """
    Download the parquet splits into the local cache on the very first run;
    afterwards the marker file short-circuits and the local cache is used.
    """
    ensure_dir(MARKER_DIR)
    if exists(CORPUS_READY_MARK):
        return

    subsets = HF_CORPUS_SUBSET.split(",")  # "ko,en" -> ["ko", "en"]
    for subset in subsets:
        load_dataset(HF_CORPUS_REPO, subset.strip(), split=HF_CORPUS_SPLIT)
    touch(CORPUS_READY_MARK)

def _get_datasets() -> Dict[str, Dataset]:
    """Lazily load every configured subset into the module-level cache."""
    global _datasets
    if not _datasets:
        subsets = HF_CORPUS_SUBSET.split(",")
        for subset in subsets:
            _datasets[subset.strip()] = load_dataset(
                HF_CORPUS_REPO, subset.strip(), split=HF_CORPUS_SPLIT
            )
    return _datasets

def fetch_contexts_by_ids(ids: List[int]) -> List[Dict[str, Any]]:
    """Return context dicts for *ids*, preserving the input id order per subset."""
    if not ids:
        return []

    datasets = _get_datasets()
    id_set = set(ids)
    results: List[Dict[str, Any]] = []

    # Walk every subset and match rows by page_id.
    for subset, ds in datasets.items():
        # NOTE(review): Dataset.filter scans the entire split on every call,
        # so this is a full corpus pass per query. A page_id -> row-index map
        # built once at startup would avoid it — confirm this is acceptable
        # at demo scale.
        rows = ds.filter(lambda r: r["page_id"] in id_set)

        id_to_row = {r["page_id"]: r for r in rows}
        for i in ids:
            r = id_to_row.get(i)
            if r:
                results.append({
                    "id": r["page_id"],
                    "title": r.get("title", ""),
                    "text": r.get("wikitext", ""),
                    "url": r.get("url", ""),
                    "metadata": {
                        "date_modified": r.get("date_modified", ""),
                        "in_language": r.get("in_language", ""),
                        "wikidata_id": r.get("wikidata_id", "")
                    }
                })
    return results
modules/embedder.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/modules/embedder.py
"""Query embedding via the HF Inference API, L2-normalized for FAISS search."""
import math
from typing import List
from huggingface_hub import InferenceClient
from config import EMBED_MODEL, HF_TOKEN

_client = InferenceClient(model=EMBED_MODEL, token=HF_TOKEN)

def _l2_normalize(vec: List[float]) -> List[float]:
    """Scale *vec* to unit L2 norm; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def get_embedding(text: str) -> List[float]:
    """Embed *text* and return a unit-length vector.

    Depending on the backing model/task, ``feature_extraction`` may return a
    single vector (``[dim]``) or a batched 2-D array (``[1, dim]``). The
    original code assumed the 2-D shape and would have sliced a scalar off a
    1-D result; both shapes are handled here.
    """
    raw = _client.feature_extraction(text)
    first = raw[0]
    if hasattr(first, "__len__"):
        # 2-D result: first row is the vector for the single input sentence.
        vec = list(first)
    else:
        # 1-D result: the array itself is the vector.
        vec = list(raw)
    return _l2_normalize([float(x) for x in vec])
modules/reranker.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/modules/reranker.py
"""Cross-encoder reranking of retrieved contexts via the HF Inference API."""
import os
from typing import List, Dict
from huggingface_hub import InferenceClient

# Model name and token come straight from the environment — unlike the other
# modules, which read config.py; presumably intentional, but worth unifying.
HF_TOKEN = os.getenv("HF_TOKEN")
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")

_client = InferenceClient(model=RERANK_MODEL, token=HF_TOKEN)

# Score cut-off; may be managed via env var or config.
THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))

def rerank(query: str, contexts: List[Dict]) -> List[Dict]:
    """
    contexts: [{"id": ..., "text": ...}, ...]
    Returns: contexts sorted by score (desc), keeping only score >= THRESHOLD.
    NOTE(review): mutates the input dicts in place by adding a "score" key.
    """
    if not contexts:
        return []

    # Reranker input: list of (query, passage) pairs.
    pairs = [(query, ctx["text"]) for ctx in contexts]

    # NOTE(review): InferenceClient in huggingface_hub 0.24.x (the version
    # pinned in requirements.txt) does not expose a `rerank` method, so this
    # call would raise AttributeError at runtime — confirm the intended
    # client method / inference endpoint.
    scores = _client.rerank(inputs=pairs)

    # Assumes the response is [{"score": float}, ...] in input order —
    # TODO confirm against the actual API response shape.
    for ctx, sc in zip(contexts, scores):
        ctx["score"] = sc["score"]

    # Sort by score, descending.
    reranked = sorted(contexts, key=lambda x: x["score"], reverse=True)

    # Keep only entries at or above the threshold.
    reranked = [c for c in reranked if c["score"] >= THRESHOLD]

    return reranked
modules/retriever.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/modules/retriever.py
"""In-memory FAISS index holder and nearest-neighbour ID lookup."""
import numpy as np
from config import TOP_K

_index = None  # in-memory FAISS index, installed via set_index()

def set_index(index_obj):
    """Install the FAISS index loaded by db.initializer."""
    global _index
    _index = index_obj

def has_index() -> bool:
    """Return True once a FAISS index has been loaded into memory."""
    return _index is not None

def retrieve_ids(query_embedding: list[float]) -> list[int]:
    """Return up to TOP_K nearest-neighbour ids for *query_embedding*.

    Raises:
        RuntimeError: if no index has been loaded yet.
    """
    if _index is None:
        raise RuntimeError("FAISS index is not loaded in memory.")
    q = np.array([query_embedding], dtype="float32")
    _, idx = _index.search(q, TOP_K)
    # FAISS pads missing neighbours with -1 when the index holds fewer than
    # TOP_K vectors; drop those instead of passing bogus ids downstream.
    return [int(i) for i in idx[0] if i >= 0]
modules/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/modules/utils.py
"""Small filesystem helpers used for marker files."""
import os

def ensure_dir(path: str):
    """Create *path* (and any missing parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)

def touch(path: str):
    """Create *path* if missing and bump its modification time.

    The original only opened the file in append mode, which creates it but
    does not refresh mtime for an already-existing file; os.utime makes the
    helper behave like the Unix `touch` command.
    """
    with open(path, "a", encoding="utf-8"):
        pass
    os.utime(path, None)

def exists(path: str) -> bool:
    """Return True if *path* exists (file or directory)."""
    return os.path.exists(path)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
fastapi==0.115.0
uvicorn[standard]==0.30.0
# jinja2 is required by fastapi.templating.Jinja2Templates (app.py);
# python-multipart is required by FastAPI Form(...) parsing (/demo).
jinja2==3.1.4
python-multipart==0.0.9

huggingface_hub==0.24.6
# datasets is imported by modules/corpus.py (load_dataset) but was missing.
datasets==2.21.0
faiss-cpu==1.8.0
numpy==1.26.4
pydantic==2.11.5
service/search.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
"""End-to-end search pipeline: embed -> FAISS retrieve -> rerank."""
from modules.embedder import get_embedding
from modules.retriever import retrieve_ids
from modules.corpus import fetch_contexts_by_ids
from modules.reranker import rerank

def search(query: str) -> list[dict]:
    """Return reranked context dicts for *query* (may be an empty list)."""
    # 1) Embed the query with the remote embedding model.
    embedding = get_embedding(query)
    # 2) Nearest-neighbour lookup in the in-memory FAISS index.
    ids = retrieve_ids(embedding)
    # 3) Hydrate ids into full context records from the corpus.
    contexts = fetch_contexts_by_ids(ids)
    # 4) Cross-encoder rerank + threshold filtering.
    reranked = rerank(query, contexts)
    return reranked
templates/index.html ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en">
<head>
  <!-- utf-8 declared explicitly: queries and results may contain Korean text -->
  <meta charset="utf-8">
  <title>RAG Search Demo</title>
</head>
<body>
  <h1>RAG Search Demo</h1>
  <form method="post" action="/demo">
    <input type="text" name="query" placeholder="Enter query" required>
    <button type="submit">Search</button>
  </form>
  {% if results %}
  <h2>Results for: {{ query }}</h2>
  <ul>
    {% for r in results %}
    <li>{{ r }}</li>
    {% endfor %}
  </ul>
  {% endif %}
</body>
</html>