ishmeet-yo commited on
Commit
b34d6f6
·
verified ·
1 Parent(s): a21be97

Upload 7 files

Browse files
Files changed (7) hide show
  1. Dockerfile +10 -0
  2. README.md +7 -11
  3. app/data/harry_potter_1.txt +0 -0
  4. app/main.py +32 -0
  5. app/rag.py +158 -0
  6. requirements.txt +6 -0
  7. templates/index.html +49 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10

WORKDIR /app

# Copy and install requirements before copying the source tree so the
# dependency layer is cached across code-only rebuilds.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# 7860 is the port Hugging Face Spaces expects the app to listen on.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,7 @@
1
- ---
2
- title: ISH Harry Potter Rag
3
- emoji: 👀
4
- colorFrom: purple
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # Harry Potter RAG
2
+
3
+ Semantic Retrieval-Augmented Generation system using FastAPI and Sentence Transformers.
4
+
5
## Run locally

```bash
uvicorn app.main:app --reload
```
 
 
 
app/data/harry_potter_1.txt ADDED
The diff for this file is too large to render. See raw diff
 
app/main.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# The Dockerfile starts the server as "uvicorn app.main:app" from /app,
# which puts /app (not /app/app) on sys.path — so the package-qualified
# import is required there. Keep the bare import as a fallback for
# running "uvicorn main:app" from inside app/.
try:
    from app.rag import load_data, retrieve_chunks
except ImportError:
    from rag import load_data, retrieve_chunks

app = FastAPI()

templates = Jinja2Templates(directory="templates")

# Mount /static only if the directory actually exists: StaticFiles raises
# at startup for a missing directory, and this repo ships no static assets.
if os.path.isdir("static"):
    app.mount("/static", StaticFiles(directory="static"), name="static")

# Build (or load from cache) the corpus chunks and retrieval heads once,
# at import time, so every request reuses them.
chunks, heads = load_data()


@app.get("/")
def home(request: Request):
    """Render the search page."""
    return templates.TemplateResponse(
        "index.html",
        {"request": request}
    )


@app.post("/search")
async def search(request: Request):
    """Retrieve chunks for the posted query.

    Expects a JSON body {"query": "..."}; returns the two best chunks
    joined as "answer" plus the full retrieved list as "sources".
    """
    body = await request.json()
    query = body["query"]

    retrieved = retrieve_chunks(query, chunks, heads)
    answer = "\n\n".join(retrieved[:2])

    return {
        "answer": answer,
        "sources": retrieved
    }
app/rag.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import hashlib
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from sklearn.feature_extraction.text import TfidfVectorizer
7
+ from sklearn.preprocessing import normalize
8
+
9
+ CACHE_DIR = "app/cache"
10
+ DATA_DIR = "app/data"
11
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
+
13
+
14
+ def compute_hash(files):
15
+ h = hashlib.md5()
16
+ for f in files:
17
+ with open(f, "rb") as fp:
18
+ h.update(fp.read())
19
+ return h.hexdigest()
20
+
21
+
22
def load_documents():
    """Read every .txt file under DATA_DIR.

    Returns (texts, files). The file list is sorted so chunk order and
    the dataset hash are deterministic — os.listdir order is arbitrary,
    and the original unsorted listing could invalidate the cache (or
    reorder chunks) between runs or machines with identical data.
    """
    files = sorted(
        os.path.join(DATA_DIR, name)
        for name in os.listdir(DATA_DIR)
        if name.endswith(".txt")
    )

    texts = []
    for path in files:
        # errors="ignore": book dumps may contain stray non-UTF-8 bytes.
        with open(path, encoding="utf-8", errors="ignore") as fp:
            texts.append(fp.read())

    return texts, files
35
+
36
+
37
def chunk_text(text, size=500, overlap=100):
    """Split *text* into word-based chunks.

    Each chunk holds up to *size* words and shares *overlap* words with
    the previous chunk. Returns [] for empty/whitespace-only text.

    Raises:
        ValueError: if overlap >= size — the original while-loop advanced
        by (size - overlap) per iteration, so a non-positive step would
        loop forever.
    """
    step = size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than size")

    words = text.split()
    return [
        " ".join(words[start:start + size])
        for start in range(0, len(words), step)
    ]
48
+
49
def chunk_documents(texts):
    """Chunk every document and flatten the results into one list."""
    return [piece for document in texts for piece in chunk_text(document)]
54
+
55
+
56
def build_embeddings(chunks):
    """Encode *chunks* into the retrieval heads used by retrieve_chunks.

    Returns a dict with three L2-normalized dense matrices ("semantic",
    "narrative", "entity"), the fitted TfidfVectorizer plus its matrix,
    and the SentenceTransformer model for query-time encoding.
    """
    model = SentenceTransformer(MODEL_NAME)

    semantic = normalize(model.encode(chunks))
    narrative = normalize(model.encode(
        ["Story context: " + c for c in chunks]
    ))
    # The original encoded the raw chunks a *second* time for the entity
    # head — byte-identical input, identical vectors, double the cost.
    # Reuse the semantic matrix; values (and cache contents) are unchanged.
    entity = semantic

    tfidf = TfidfVectorizer()
    tfidf_matrix = tfidf.fit_transform(chunks)

    return {
        "semantic": semantic,
        "narrative": narrative,
        "entity": entity,
        "tfidf": tfidf,
        "tfidf_matrix": tfidf_matrix,
        "model": model
    }
76
+
77
+
78
def save_cache(chunks, heads, dataset_hash):
    """Persist chunks, dense heads, tf-idf artifacts and the corpus hash.

    Layout under CACHE_DIR: one .npy per dense head, pickles for the
    chunk list and tf-idf objects, and hash.txt for cache invalidation.
    The model itself is not cached (load_cache reloads it by name).
    """
    os.makedirs(CACHE_DIR, exist_ok=True)

    for head_name in ("semantic", "narrative", "entity"):
        np.save(f"{CACHE_DIR}/{head_name}.npy", heads[head_name])

    pickled = (
        ("chunks.pkl", chunks),
        ("tfidf.pkl", heads["tfidf"]),
        ("tfidf_matrix.pkl", heads["tfidf_matrix"]),
    )
    for filename, payload in pickled:
        with open(f"{CACHE_DIR}/{filename}", "wb") as fh:
            pickle.dump(payload, fh)

    with open(f"{CACHE_DIR}/hash.txt", "w") as fh:
        fh.write(dataset_hash)
96
+
97
def load_cache():
    """Restore (chunks, heads) previously written by save_cache."""
    with open(f"{CACHE_DIR}/chunks.pkl", "rb") as fh:
        chunks = pickle.load(fh)

    heads = {
        name: np.load(f"{CACHE_DIR}/{name}.npy")
        for name in ("semantic", "narrative", "entity")
    }

    for key in ("tfidf", "tfidf_matrix"):
        with open(f"{CACHE_DIR}/{key}.pkl", "rb") as fh:
            heads[key] = pickle.load(fh)

    # The encoder is not part of the cache; reload it for query encoding.
    heads["model"] = SentenceTransformer(MODEL_NAME)
    return chunks, heads
115
+
116
def load_data():
    """Return (chunks, heads), rebuilding embeddings only when needed.

    Compares the MD5 of the current data files against the hash stored
    alongside the cache; on a match the cached artifacts are loaded,
    otherwise everything is re-encoded and the cache is refreshed.
    """
    texts, files = load_documents()
    chunks = chunk_documents(texts)
    dataset_hash = compute_hash(files)

    hash_path = f"{CACHE_DIR}/hash.txt"

    cached_hash = None
    if os.path.exists(hash_path):
        with open(hash_path) as fh:
            cached_hash = fh.read().strip()

    if cached_hash == dataset_hash:
        print("Loading embeddings from cache")
        return load_cache()

    print("Building embeddings")
    heads = build_embeddings(chunks)
    save_cache(chunks, heads, dataset_hash)
    return chunks, heads
137
+
138
+
139
def retrieve_chunks(query, chunks, heads, k=5):
    """Return the top-*k* chunks for *query*.

    Score = 0.45 * semantic cosine + 0.35 * narrative cosine
          + 0.20 * tf-idf keyword overlap.
    NOTE(review): the "entity" head built by build_embeddings is never
    scored here — confirm whether that is intentional.
    """
    encoder = heads["model"]

    query_sem = normalize(encoder.encode([query]))
    query_nav = normalize(encoder.encode(["Story question: " + query]))

    semantic_scores = heads["semantic"] @ query_sem.T
    narrative_scores = heads["narrative"] @ query_nav.T

    query_tfidf = heads["tfidf"].transform([query])
    keyword_scores = heads["tfidf_matrix"] @ query_tfidf.T

    combined = (
        0.45 * semantic_scores +
        0.35 * narrative_scores +
        0.20 * keyword_scores.toarray()
    )

    ranked = np.argsort(combined.flatten())[::-1][:k]
    return [chunks[i] for i in ranked]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ sentence-transformers
4
+ numpy
5
+ scikit-learn
6
+ jinja2
templates/index.html ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html>
<head>
    <title>Harry Potter RAG</title>
    <script src="https://cdn.tailwindcss.com"></script>
</head>

<body class="bg-zinc-50">
    <div class="max-w-3xl mx-auto py-20">
        <h1 class="text-4xl font-semibold mb-6">
            Harry Potter Semantic Search
        </h1>

        <input
            id="query"
            class="w-full p-4 rounded-xl shadow"
            placeholder="Ask something..."
        />

        <button
            onclick="search()"
            class="mt-4 px-6 py-3 bg-black text-white rounded-xl"
        >
            Search
        </button>

        <div id="answer" class="mt-8"></div>
    </div>

    <script>
        async function search() {
            const q = document.getElementById("query").value;

            const res = await fetch("/search", {
                method: "POST",
                headers: {"Content-Type": "application/json"},
                body: JSON.stringify({query: q})
            });

            const data = await res.json();

            // Insert the answer as text, not HTML: the original used
            // innerHTML with the raw passage, so any "<" or "&" in the
            // book text would be parsed as markup and break the output.
            const card = document.createElement("div");
            card.className = "p-6 bg-white rounded-xl shadow";
            card.textContent = data.answer;
            document.getElementById("answer").replaceChildren(card);
        }
    </script>
</body>
</html>