Satyam0077 committed on
Commit
eec68b7
·
verified ·
1 Parent(s): d13c723

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +65 -65
rag_pipeline.py CHANGED
@@ -1,65 +1,65 @@
1
- import faiss
2
- import numpy as np
3
- import pandas as pd
4
- from sentence_transformers import SentenceTransformer
5
- import re
6
-
7
class QuoteRAG:
    """Retrieve quotes that are semantically close to a free-text query.

    Every value of the ``quote`` column in the dataset is embedded with a
    SentenceTransformer model and stored in a FAISS ``IndexFlatL2``.
    ``search`` embeds the query with the same model and looks up its
    nearest neighbours in that index.
    """

    def __init__(self, model_path="models/fine_tuned_model", data_path="data/english_quotes.csv"):
        """Load the model and the dataset, then build the FAISS index.

        Args:
            model_path: Location of the fine-tuned model. If it cannot be
                loaded, the stock "all-MiniLM-L6-v2" model is used instead.
            data_path: CSV file with ``quote`` and ``author`` columns
                (``tags`` is optional).

        Raises:
            KeyError: If the CSV has no ``quote`` column.
            ValueError: If the CSV contains no usable quotes.
        """
        # Best-effort load of the fine-tuned model. ``Exception`` (rather
        # than a bare ``except``) keeps the fallback but lets
        # KeyboardInterrupt / SystemExit propagate.
        try:
            self.model = SentenceTransformer(model_path)
            print("Loaded fine-tuned model")
        except Exception:
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Loaded base model")

        # Rows without quote text cannot be embedded. Drop them and renumber
        # so that row positions stay aligned with the ids stored in FAISS.
        raw_df = pd.read_csv(data_path)
        self.df = raw_df.dropna(subset=["quote"]).reset_index(drop=True)
        if self.df.empty:
            raise ValueError(f"No quotes found in {data_path!r}")

        # Encode all quotes; FAISS expects a contiguous float32 matrix.
        self.embeddings = np.ascontiguousarray(
            self.model.encode(self.df["quote"].astype(str).tolist(), convert_to_numpy=True),
            dtype=np.float32,
        )
        d = self.embeddings.shape[1]

        # Build FAISS index
        self.index = faiss.IndexFlatL2(d)
        self.index.add(self.embeddings)
        print("FAISS index built with", len(self.df), "quotes")

    @staticmethod
    def _author_in_query(author, query_lower):
        """Return True when a non-empty author name occurs in the lowercased query."""
        # Missing authors come back from pandas as NaN (a float), not a string.
        if not isinstance(author, str):
            return False
        name = author.strip().lower()
        # An empty name is a substring of every query, so it must not count.
        return bool(name) and name in query_lower

    def search(self, query, top_k=5):
        """Return up to ``top_k`` quotes ranked by closeness to ``query``.

        Up to ``top_k * 3`` candidates are fetched from the index. If the
        full author name of any candidate occurs in the query
        (case-insensitive), only those candidates are kept; otherwise the
        nearest candidates are returned unfiltered.

        Args:
            query: Free-text search string.
            top_k: Maximum number of results to return.

        Returns:
            A list of dicts with the keys ``quote``, ``author``, ``tags`` and
            ``similarity``, where ``similarity`` is ``1 / (1 + distance)``
            rounded to three decimals (higher is better).
        """
        if top_k <= 0:
            return []

        # Never ask the index for more neighbours than there are quotes.
        n_candidates = min(top_k * 3, len(self.df))
        if n_candidates == 0:
            return []

        query_emb = np.ascontiguousarray(
            self.model.encode([query], convert_to_numpy=True), dtype=np.float32
        )
        distances, indices = self.index.search(query_emb, n_candidates)

        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads the result with id -1 when it has fewer than k
            # vectors; ``iloc[-1]`` would silently return the last row.
            if idx < 0:
                continue
            row = self.df.iloc[int(idx)]

            results.append({
                "quote": row["quote"],
                "author": row["author"],
                "tags": row.get("tags", ""),
                # Normalized similarity: 0-1 (higher is better)
                "similarity": round(1 / (1 + float(dist)), 3),
            })

        # Simple author filter if author name is in query
        query_lower = query.lower()
        author_filtered = [
            r for r in results if self._author_in_query(r["author"], query_lower)
        ]

        return (author_filtered or results)[:top_k]
58
-
59
-
60
if __name__ == "__main__":
    # Demo run: build the pipeline with its default paths and print every
    # hit returned for one sample query.
    demo_query = "Quotes about insanity attributed to Einstein"
    pipeline = QuoteRAG()
    for hit in pipeline.search(demo_query):
        print(hit)
 
1
+ import faiss
2
+ import numpy as np
3
+ import pandas as pd
4
+ from sentence_transformers import SentenceTransformer
5
+ import re
6
+
7
class QuoteRAG:
    """Retrieve quotes that are semantically close to a free-text query.

    Every value of the ``quote`` column in the dataset is embedded with a
    SentenceTransformer model and stored in a FAISS ``IndexFlatL2``.
    ``search`` embeds the query with the same model and looks up its
    nearest neighbours in that index.
    """

    def __init__(self, model_path="models/fine_tuned_model", data_path="english_quotes.csv"):
        """Load the model and the dataset, then build the FAISS index.

        Args:
            model_path: Location of the fine-tuned model. If it cannot be
                loaded, the stock "all-MiniLM-L6-v2" model is used instead.
            data_path: CSV file with ``quote`` and ``author`` columns
                (``tags`` is optional).

        Raises:
            KeyError: If the CSV has no ``quote`` column.
            ValueError: If the CSV contains no usable quotes.
        """
        # Best-effort load of the fine-tuned model. ``Exception`` (rather
        # than a bare ``except``) keeps the fallback but lets
        # KeyboardInterrupt / SystemExit propagate.
        try:
            self.model = SentenceTransformer(model_path)
            print("Loaded fine-tuned model")
        except Exception:
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            print("Loaded base model")

        # Rows without quote text cannot be embedded. Drop them and renumber
        # so that row positions stay aligned with the ids stored in FAISS.
        raw_df = pd.read_csv(data_path)
        self.df = raw_df.dropna(subset=["quote"]).reset_index(drop=True)
        if self.df.empty:
            raise ValueError(f"No quotes found in {data_path!r}")

        # Encode all quotes; FAISS expects a contiguous float32 matrix.
        self.embeddings = np.ascontiguousarray(
            self.model.encode(self.df["quote"].astype(str).tolist(), convert_to_numpy=True),
            dtype=np.float32,
        )
        d = self.embeddings.shape[1]

        # Build FAISS index
        self.index = faiss.IndexFlatL2(d)
        self.index.add(self.embeddings)
        print("FAISS index built with", len(self.df), "quotes")

    @staticmethod
    def _author_in_query(author, query_lower):
        """Return True when a non-empty author name occurs in the lowercased query."""
        # Missing authors come back from pandas as NaN (a float), not a string.
        if not isinstance(author, str):
            return False
        name = author.strip().lower()
        # An empty name is a substring of every query, so it must not count.
        return bool(name) and name in query_lower

    def search(self, query, top_k=5):
        """Return up to ``top_k`` quotes ranked by closeness to ``query``.

        Up to ``top_k * 3`` candidates are fetched from the index. If the
        full author name of any candidate occurs in the query
        (case-insensitive), only those candidates are kept; otherwise the
        nearest candidates are returned unfiltered.

        Args:
            query: Free-text search string.
            top_k: Maximum number of results to return.

        Returns:
            A list of dicts with the keys ``quote``, ``author``, ``tags`` and
            ``similarity``, where ``similarity`` is ``1 / (1 + distance)``
            rounded to three decimals (higher is better).
        """
        if top_k <= 0:
            return []

        # Never ask the index for more neighbours than there are quotes.
        n_candidates = min(top_k * 3, len(self.df))
        if n_candidates == 0:
            return []

        query_emb = np.ascontiguousarray(
            self.model.encode([query], convert_to_numpy=True), dtype=np.float32
        )
        distances, indices = self.index.search(query_emb, n_candidates)

        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads the result with id -1 when it has fewer than k
            # vectors; ``iloc[-1]`` would silently return the last row.
            if idx < 0:
                continue
            row = self.df.iloc[int(idx)]

            results.append({
                "quote": row["quote"],
                "author": row["author"],
                "tags": row.get("tags", ""),
                # Normalized similarity: 0-1 (higher is better)
                "similarity": round(1 / (1 + float(dist)), 3),
            })

        # Simple author filter if author name is in query
        query_lower = query.lower()
        author_filtered = [
            r for r in results if self._author_in_query(r["author"], query_lower)
        ]

        return (author_filtered or results)[:top_k]
58
+
59
+
60
if __name__ == "__main__":
    # Demo run: build the pipeline with its default paths and print every
    # hit returned for one sample query.
    demo_query = "Quotes about insanity attributed to Einstein"
    pipeline = QuoteRAG()
    for hit in pipeline.search(demo_query):
        print(hit)