pluto90 committed on
Commit
ab4db24
·
verified ·
1 Parent(s): 0f97f06

Upload rag_service.py

Browse files
Files changed (1) hide show
  1. app/core/rag_service.py +98 -0
app/core/rag_service.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/core/rag_service.py
2
+
3
+ from app.core.embedding_engine import embedder, COLLECTION_NAME
4
+ from qdrant_client.http.models import Filter, FieldCondition, MatchValue
5
+ from qdrant_client import QdrantClient
6
+ from app.core.config import QDRANT_URL, QDRANT_API_KEY
7
+
8
# Module-level Qdrant client used by get_rag_context below; the endpoint and
# credentials come from app.core.config.
# NOTE(review): check_compatibility=False presumably skips the client/server
# version check at construction time — confirm against the qdrant-client docs.
qdrant_client = QdrantClient(
    api_key=QDRANT_API_KEY,
    url=QDRANT_URL,
    check_compatibility=False,
)
13
+
14
+
15
+ # def get_rag_context(question: str, doc_id: str):
16
+ # question_vector = embedder.encode([question])[0].tolist()
17
+
18
+ # hits = qdrant_client.query_points(
19
+ # collection_name=COLLECTION_NAME,
20
+ # query=question_vector,
21
+ # query_filter=Filter(
22
+ # must=[FieldCondition(key="doc_id", match=MatchValue(value=doc_id))]
23
+ # ),
24
+ # limit=5,
25
+ # ).points
26
+
27
+ # # context = "\n".join([hit.payload["text"] for hit in hits])
28
+
29
+ # contexts = []
30
+ # sources = []
31
+
32
+ # for hit in hits:
33
+ # text = hit.payload.get("text", "")
34
+ # contexts.append(text)
35
+
36
+ # sources.append({
37
+ # "text": text[:300], # limit for UI
38
+ # # add page if you have it later
39
+ # })
40
+
41
+ # context = "\n".join(contexts)
42
+
43
+ # return context, sources
44
+
45
+
46
+ # def get_rag_context(query, doc_id, top_k=3):
47
+ # query_vector = embedder.encode(query).tolist()
48
+
49
+ # results = qdrant_client.query_points(
50
+ # collection_name=doc_id,
51
+ # query=query_vector,
52
+ # limit=top_k
53
+ # )
54
+
55
+ # points = results.points
56
+
57
+ # if not points:
58
+ # return "", [], []
59
+
60
+ # context = "\n".join([p.payload["text"] for p in points])
61
+ # sources = [p.payload.get("source") for p in points]
62
+ # scores = [p.score for p in points]
63
+
64
+ # return context, sources, scores
65
+
66
+
67
+
68
def get_rag_context(query, doc_id, top_k=3):
    """Retrieve the stored chunks of one document that best match *query*.

    The query is embedded with the module-level ``embedder`` and searched
    against the ``"smartnotes"`` collection, restricted by a payload filter
    to points whose ``doc_id`` field equals *doc_id*.

    Args:
        query: Text to embed and search for.
        doc_id: Value matched against the ``doc_id`` payload field.
        top_k: Maximum number of points to request (default 3).

    Returns:
        A ``(context, sources, scores)`` tuple. ``context`` is the ``text``
        payload of every returned point joined with newlines (a point without
        usable text contributes an empty line), ``sources`` holds each
        point's ``source`` payload value (``None`` when absent) and
        ``scores`` holds each point's score. All three are aligned with the
        order of the returned points. When nothing matches the result is
        ``("", [], [])``.
    """
    query_vector = embedder.encode(query).tolist()

    # All documents share one collection; the payload filter scopes the
    # search to the requested document.
    doc_filter = Filter(
        must=[FieldCondition(key="doc_id", match=MatchValue(value=doc_id))]
    )
    results = qdrant_client.query_points(
        collection_name="smartnotes",
        query=query_vector,
        limit=top_k,
        query_filter=doc_filter,
    )

    points = results.points
    if not points:
        return "", [], []

    # A point may come back with no payload at all, or with a payload that
    # lacks "text" (or stores null there). Normalise instead of raising so a
    # single malformed point does not fail the whole lookup.
    payloads = [point.payload or {} for point in points]
    context = "\n".join(payload.get("text") or "" for payload in payloads)
    sources = [payload.get("source") for payload in payloads]
    scores = [point.score for point in points]

    return context, sources, scores
98
+