datasciencesage committed
Commit 40e6b7a · verified · 1 Parent(s): 5a1de22

Upload 5 files

app/services/embedding_service.py ADDED
@@ -0,0 +1,102 @@
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+ import os
+ from loguru import logger
+
+
+ class EmbeddingService:
+     """Handles text embeddings via sentence-transformers.
+     A thin wrapper around the underlying model."""
+
+     def __init__(self, model_name=None):
+         try:
+             self.model_name = model_name or os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2")
+             logger.info(f"Loading embedding model: {self.model_name}")
+             self.model = SentenceTransformer(self.model_name)
+             self.dimension = self.model.get_sentence_embedding_dimension()
+             logger.success(f"Model loaded. Embedding dimension: {self.dimension}")
+         except Exception as e:
+             logger.error(f"Failed to load embedding model: {str(e)}")
+             raise
+
+     def embed_single(self, text):
+         """Embed one text; blank input yields a zero vector."""
+         try:
+             if not text or not text.strip():
+                 return np.zeros(self.dimension, dtype=np.float32)
+
+             embedding = self.model.encode(
+                 text,
+                 normalize_embeddings=True,
+                 show_progress_bar=False,
+                 convert_to_numpy=True
+             )
+             return embedding.astype(np.float32)
+         except Exception as e:
+             logger.error(f"Single embedding failed: {str(e)}")
+             return np.zeros(self.dimension, dtype=np.float32)
+
+     def embed_batch(self, texts, batch_size=32):
+         """Embed a list of texts; returns an array of shape (len(texts), dimension)."""
+         try:
+             if not texts:
+                 # keep the output 2-D so callers can rely on the shape
+                 return np.zeros((0, self.dimension), dtype=np.float32)
+
+             valid_texts = []
+             valid_indices = []
+             for i, text in enumerate(texts):
+                 if text and text.strip():
+                     valid_texts.append(text)
+                     valid_indices.append(i)
+
+             if not valid_texts:
+                 return np.zeros((len(texts), self.dimension), dtype=np.float32)
+
+             # encode all valid texts in batches
+             embeddings = self.model.encode(
+                 valid_texts,
+                 batch_size=batch_size,
+                 normalize_embeddings=True,
+                 show_progress_bar=False,
+                 convert_to_numpy=True
+             )
+
+             # scatter embeddings back to their original positions;
+             # blank inputs keep a zero vector
+             output = np.zeros((len(texts), self.dimension), dtype=np.float32)
+             for i, valid_idx in enumerate(valid_indices):
+                 output[valid_idx] = embeddings[i]
+
+             return output
+         except Exception as e:
+             logger.error(f"Batch embedding failed: {str(e)}")
+             return np.zeros((len(texts), self.dimension), dtype=np.float32)
+
+     def cosine_similarity(self, vec1, vec2):
+         """Cosine similarity between two vectors."""
+         try:
+             norm1 = np.linalg.norm(vec1)
+             norm2 = np.linalg.norm(vec2)
+
+             if norm1 == 0 or norm2 == 0:
+                 return 0.0
+
+             similarity = np.dot(vec1, vec2) / (norm1 * norm2)
+             return float(np.clip(similarity, -1.0, 1.0))
+         except Exception as e:
+             logger.error(f"Cosine similarity failed: {str(e)}")
+             return 0.0
+
+     def batch_cosine_similarity(self, query_vec, corpus_vecs):
+         """Compare one query vector against many corpus vectors.
+         Vectorized, so much faster than looping pair by pair."""
+         try:
+             query_norm = query_vec / (np.linalg.norm(query_vec) + 1e-8)
+
+             # for unit-norm corpus vectors, the dot product equals cosine similarity
+             similarities = np.dot(corpus_vecs, query_norm)
+             return np.clip(similarities, -1.0, 1.0)
+         except Exception as e:
+             logger.error(f"Batch cosine similarity failed: {str(e)}")
+             return np.zeros(len(corpus_vecs))
+
+     def get_dimension(self):
+         """Get embedding dimension"""
+         return self.dimension
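
A quick usage sketch for this service (illustrative only: the corpus strings are made up, and it assumes sentence-transformers is installed and the default all-MiniLM-L6-v2 model can be downloaded):

    from app.services.embedding_service import EmbeddingService

    svc = EmbeddingService()  # loads all-MiniLM-L6-v2 unless EMBED_MODEL is set

    corpus = ["bureau credit score", "business vintage in months", "overdue amount"]
    corpus_vecs = svc.embed_batch(corpus)  # shape (3, svc.get_dimension())

    query_vec = svc.embed_single("minimum CIBIL score threshold")
    sims = svc.batch_cosine_similarity(query_vec, corpus_vecs)
    best = int(sims.argmax())
    print(corpus[best], float(sims[best]))  # expected to favor the credit-score text

Because embed_batch normalizes its outputs, batch_cosine_similarity reduces to a single matrix-vector product.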
app/services/key_mapper.py ADDED
@@ -0,0 +1,222 @@
+ import numpy as np
+ from rank_bm25 import BM25Okapi
+ import re
+ import os
+ from loguru import logger
+ from app.constants import SAMPLE_STORE_KEYS, build_key_search_text
+ from app.models import KeyMapping
+ from app.services.embedding_service import EmbeddingService
+
+
+ class KeyMapper:
+     """Hybrid retrieval: combines semantic (dense) search with BM25 keyword matching.
+     TODO: add cross-encoder reranking later if needed."""
+
+     def __init__(self, embedding_service):
+         try:
+             self.embed_service = embedding_service
+             self.rrf_k = int(os.getenv("RRF_K", "60"))  # k=60 worked best in testing
+             self.threshold = float(os.getenv("SIM_THRESHOLD", "0.7"))
+
+             logger.info("Initializing KeyMapper...")
+
+             self.keys = SAMPLE_STORE_KEYS
+
+             # build searchable text for each key
+             self.key_texts = [build_key_search_text(k) for k in self.keys]
+             logger.debug(f"Built {len(self.key_texts)} key search texts")
+
+             # precompute embeddings so we don't recompute them per query
+             logger.info("Computing key embeddings...")
+             self.key_embeddings = self.embed_service.embed_batch(self.key_texts)
+             logger.debug(f"Key embeddings shape: {self.key_embeddings.shape}")
+
+             # set up BM25 for keyword matching
+             logger.info("Building BM25 index...")
+             self.tokenized_keys = [self.tokenize(text) for text in self.key_texts]
+             self.bm25 = BM25Okapi(self.tokenized_keys)
+             logger.success("KeyMapper initialized successfully")
+         except Exception as e:
+             logger.error(f"Failed to initialize KeyMapper: {str(e)}")
+             raise
+
+     def tokenize(self, text):
+         """Simple tokenization: lowercase, split on word boundaries."""
+         try:
+             return re.findall(r'\w+', text.lower())
+         except Exception as e:
+             logger.error(f"Tokenization failed: {str(e)}")
+             return []
+
+     def extract_key_phrases(self, prompt):
+         """Extract candidate phrases (full prompt, bigrams, trigrams, long tokens)
+         so a match can be attributed to a specific part of the prompt."""
+         try:
+             phrases = [prompt.strip()]
+
+             tokens = self.tokenize(prompt)
+
+             # bigrams: pairs of adjacent words
+             for i in range(len(tokens) - 1):
+                 phrases.append(f"{tokens[i]} {tokens[i+1]}")
+
+             # trigrams: three-word windows
+             for i in range(len(tokens) - 2):
+                 phrases.append(f"{tokens[i]} {tokens[i+1]} {tokens[i+2]}")
+
+             # single tokens, but only longer ones (skip short words like 'is', 'or')
+             phrases.extend([t for t in tokens if len(t) > 3])
+
+             # de-duplicate while preserving order
+             seen = set()
+             unique = []
+             for p in phrases:
+                 if p not in seen:
+                     seen.add(p)
+                     unique.append(p)
+
+             return unique[:15]  # cap the number of candidates
+         except Exception as e:
+             logger.error(f"Phrase extraction failed: {str(e)}")
+             return [prompt]  # fall back to just the prompt
+
+     def compute_dense_ranks(self, prompt):
+         """Semantic similarity ranking via embeddings."""
+         try:
+             prompt_emb = self.embed_service.embed_single(prompt)
+
+             similarities = self.embed_service.batch_cosine_similarity(
+                 prompt_emb,
+                 self.key_embeddings
+             )
+
+             # indices sorted by descending similarity
+             ranks = np.argsort(-similarities)
+
+             # convert to 1-based rank positions per key
+             rank_positions = np.zeros(len(self.keys), dtype=int)
+             for pos, idx in enumerate(ranks):
+                 rank_positions[idx] = pos + 1
+
+             return rank_positions, similarities
+         except Exception as e:
+             logger.error(f"Dense ranking failed: {str(e)}")
+             # return default ranks if something breaks
+             default_ranks = np.arange(1, len(self.keys) + 1)
+             default_sims = np.zeros(len(self.keys))
+             return default_ranks, default_sims
+
+     def compute_sparse_ranks(self, prompt):
+         """Keyword-based ranking with BM25."""
+         try:
+             prompt_tokens = self.tokenize(prompt)
+             bm25_scores = self.bm25.get_scores(prompt_tokens)
+
+             ranks = np.argsort(-bm25_scores)
+
+             rank_positions = np.zeros(len(self.keys), dtype=int)
+             for pos, idx in enumerate(ranks):
+                 rank_positions[idx] = pos + 1
+
+             return rank_positions, bm25_scores
+         except Exception as e:
+             logger.error(f"Sparse ranking failed: {str(e)}")
+             default_ranks = np.arange(1, len(self.keys) + 1)
+             default_scores = np.zeros(len(self.keys))
+             return default_ranks, default_scores
+
+     def apply_rrf(self, dense_ranks, sparse_ranks):
+         """Reciprocal rank fusion: score = sum over rankers of 1 / (k + rank).
+         In testing this beat a weighted average of raw scores."""
+         try:
+             rrf_scores = (1.0 / (self.rrf_k + dense_ranks)) + \
+                          (1.0 / (self.rrf_k + sparse_ranks))
+             return rrf_scores
+         except Exception as e:
+             logger.error(f"RRF fusion failed: {str(e)}")
+             # fall back to dense ranks only
+             return 1.0 / (self.rrf_k + dense_ranks)
+
+     def map_keys(self, prompt, top_k=5):
+         """Map a user prompt to actual store keys."""
+         try:
+             logger.info(f"Mapping keys for prompt: {prompt[:50]}...")
+
+             # get rankings from both methods
+             dense_ranks, dense_sims = self.compute_dense_ranks(prompt)
+             sparse_ranks, sparse_scores = self.compute_sparse_ranks(prompt)
+
+             # combine them with RRF
+             rrf_scores = self.apply_rrf(dense_ranks, sparse_ranks)
+
+             # sort by combined score
+             sorted_indices = np.argsort(-rrf_scores)
+
+             # extract phrases from the prompt and embed each one once,
+             # rather than re-embedding inside the per-key loop
+             key_phrases = self.extract_key_phrases(prompt)
+             phrase_embs = [self.embed_service.embed_single(p) for p in key_phrases]
+
+             # maximum possible RRF score (rank 1 in both rankers), used to normalize
+             max_rrf = 2.0 / (self.rrf_k + 1)
+
+             # build the mappings
+             mappings = []
+             for idx in sorted_indices:
+                 # normalize score to the 0-1 range
+                 normalized_score = float(rrf_scores[idx] / max_rrf)
+
+                 # find which phrase matches this key best
+                 key_emb = self.key_embeddings[idx]
+                 best_phrase = prompt  # default to the full prompt
+                 best_phrase_sim = dense_sims[idx]
+
+                 for phrase, phrase_emb in zip(key_phrases, phrase_embs):
+                     phrase_sim = self.embed_service.cosine_similarity(phrase_emb, key_emb)
+                     if phrase_sim > best_phrase_sim:
+                         best_phrase = phrase
+                         best_phrase_sim = phrase_sim
+
+                 mappings.append(KeyMapping(
+                     user_phrase=best_phrase[:50],
+                     mapped_to=self.keys[idx]['value'],
+                     similarity=float(np.clip(normalized_score, 0.0, 1.0))
+                 ))
+
+                 if len(mappings) >= top_k:
+                     break
+
+             logger.success(f"Mapped {len(mappings)} keys successfully")
+             return mappings
+
+         except Exception as e:
+             logger.error(f"Key mapping failed: {str(e)}")
+             # return an empty list if everything breaks
+             return []
+
+     def get_top_keys(self, prompt, top_k=5, min_similarity=None):
+         """Get top keys with full metadata."""
+         try:
+             threshold = min_similarity if min_similarity is not None else self.threshold
+
+             # retrieve more than needed, then filter by threshold
+             mappings = self.map_keys(prompt, top_k=top_k * 2)
+
+             filtered = [m for m in mappings if m.similarity >= threshold]
+
+             # attach full key details
+             result = []
+             for mapping in filtered[:top_k]:
+                 key_obj = next((k for k in self.keys if k['value'] == mapping.mapped_to), None)
+                 if key_obj:
+                     result.append({
+                         **key_obj,
+                         'similarity': mapping.similarity,
+                         'matched_phrase': mapping.user_phrase
+                     })
+
+             return result
+         except Exception as e:
+             logger.error(f"get_top_keys failed: {str(e)}")
+             return []
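
The RRF fusion in apply_rrf is easiest to see with numbers. A small self-contained sketch (the three keys and their ranks are invented):

    import numpy as np

    k = 60  # same default as RRF_K above

    # 1-based ranks of three keys under each ranker
    dense_ranks = np.array([1, 2, 3])   # semantic ranking
    sparse_ranks = np.array([3, 1, 2])  # BM25 ranking

    rrf = 1.0 / (k + dense_ranks) + 1.0 / (k + sparse_ranks)
    # key 0: 1/61 + 1/63 ~= 0.03227
    # key 1: 1/62 + 1/61 ~= 0.03252  <- best: solid in both rankers
    # key 2: 1/63 + 1/62 ~= 0.03200
    print(np.argsort(-rrf))  # [1 0 2]

Because only ranks enter the formula, RRF sidesteps the problem of mixing incomparable score scales (cosine similarity vs. BM25).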
app/services/rag_service.py ADDED
@@ -0,0 +1,270 @@
+ import numpy as np
+ import faiss
+ import json
+ import os
+ from openai import OpenAI
+ from loguru import logger
+ from app.services.embedding_service import EmbeddingService
+ from app.constants import POLICIES
+
+
+ class RAGService:
+     """Policy retrieval backed by a FAISS index.
+     CRAG (corrective RAG): judge the results and retry with a refined query if they look bad."""
+
+     def __init__(self, embedding_service):
+         try:
+             self.embed_service = embedding_service
+             self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+             logger.info("Initializing RAG Service...")
+
+             self.policies = POLICIES.copy()
+             self.build_index()
+
+             logger.success(f"RAG Service initialized with {len(self.policies)} policies")
+         except Exception as e:
+             logger.error(f"Failed to initialize RAG Service: {str(e)}")
+             raise
+
+     def build_index(self):
+         """Build a FAISS index from the policy docs."""
+         try:
+             if not self.policies:
+                 logger.warning("No policies to index")
+                 self.index = None
+                 self.policy_embeddings = None
+                 return
+
+             logger.info(f"Embedding {len(self.policies)} policy documents...")
+
+             # embed all policies
+             self.policy_embeddings = self.embed_service.embed_batch(self.policies)
+
+             # inner-product index; embeddings are L2-normalized, so IP equals cosine similarity
+             dimension = self.embed_service.get_dimension()
+             self.index = faiss.IndexFlatIP(dimension)
+
+             self.index.add(self.policy_embeddings.astype('float32'))
+
+             logger.debug(f"FAISS index built with {self.index.ntotal} vectors")
+         except Exception as e:
+             logger.error(f"Index building failed: {str(e)}")
+             self.index = None
+             self.policy_embeddings = None
+
+     def add_documents(self, new_docs):
+         """Add new docs to the index on the fly."""
+         try:
+             if not new_docs:
+                 return
+
+             logger.info(f"Adding {len(new_docs)} temporary documents...")
+
+             if self.index is None:
+                 # the index was never built (or the build failed); rebuild from scratch
+                 self.policies.extend(new_docs)
+                 self.build_index()
+                 return
+
+             new_embeddings = self.embed_service.embed_batch(new_docs)
+             self.index.add(new_embeddings.astype('float32'))
+             self.policies.extend(new_docs)
+
+             # stack the new embeddings under the old ones
+             self.policy_embeddings = np.vstack([self.policy_embeddings, new_embeddings])
+
+             logger.debug(f"Index now contains {self.index.ntotal} documents")
+         except Exception as e:
+             logger.error(f"Failed to add documents: {str(e)}")
+
+     def retrieve(self, query, top_k=3):
+         """Basic FAISS retrieval."""
+         try:
+             if self.index is None or self.index.ntotal == 0:
+                 logger.warning("Index is empty, returning no results")
+                 return []
+
+             # embed and search
+             query_emb = self.embed_service.embed_single(query).reshape(1, -1)
+             scores, indices = self.index.search(query_emb.astype('float32'), top_k)
+
+             results = []
+             for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
+                 if idx < len(self.policies):
+                     results.append({
+                         'text': self.policies[idx],
+                         'score': float(score),
+                         'index': int(idx),
+                         'rank': i + 1
+                     })
+
+             logger.debug(f"Retrieved {len(results)} documents")
+             return results
+         except Exception as e:
+             logger.error(f"Retrieval failed: {str(e)}")
+             return []
+
+     def judge_relevance(self, query, documents):
+         """Use the LLM to score how relevant each retrieved doc is;
+         this helps filter out poor results."""
+         try:
+             if not documents:
+                 return []
+
+             doc_texts = "\n\n".join([
+                 f"DOCUMENT {i+1}:\n{doc['text']}"
+                 for i, doc in enumerate(documents)
+             ])
+
+             judge_prompt = f"""You are an expert relevance evaluator for a loan application rule generation system.
+
+ QUERY: {query}
+
+ RETRIEVED DOCUMENTS:
+ {doc_texts}
+
+ Task: Rate the relevance of each document to the query on a scale of 0.0 to 1.0.
+ - 1.0 = Highly relevant, directly helps answer the query
+ - 0.5 = Somewhat relevant, provides context
+ - 0.0 = Not relevant at all
+
+ Respond ONLY with a JSON array of scores, one per document in order.
+ Example: [0.9, 0.6, 0.2]
+
+ Scores:"""
+
+             response = self.client.chat.completions.create(
+                 model="gpt-4o-mini",
+                 messages=[
+                     {"role": "system", "content": "You are a relevance scoring expert. Respond only with a JSON array of numbers."},
+                     {"role": "user", "content": judge_prompt}
+                 ],
+                 temperature=0.1,
+                 max_tokens=100
+             )
+
+             content = response.choices[0].message.content.strip()
+             scores = json.loads(content)
+
+             # clamp to the 0-1 range
+             scores = [max(0.0, min(1.0, float(s))) for s in scores]
+
+             # pad if the LLM didn't return enough scores
+             while len(scores) < len(documents):
+                 scores.append(0.5)
+
+             logger.debug(f"LLM judge scores: {scores}")
+             return scores[:len(documents)]
+
+         except Exception as e:
+             logger.error(f"LLM judge failed: {str(e)}")
+             # fallback: squash the raw retrieval scores into 0-1
+             return [doc['score'] / (doc['score'] + 1.0) for doc in documents]
+
+     def refine_query(self, original_query, low_relevance_docs):
+         """If results are poor, ask the LLM to rewrite the query;
+         adding more specific terms usually helps."""
+         try:
+             refine_prompt = f"""Original query: "{original_query}"
+
+ The retrieved documents were not very relevant. Suggest a better search query that focuses on key loan application terms like:
+ - Bureau score, credit score, CIBIL
+ - Business vintage, age
+ - Overdue amounts, DPD
+ - Income, FOIR
+ - GST, banking metrics
+
+ Respond with ONLY the improved query, no explanation.
+
+ Improved query:"""
+
+             response = self.client.chat.completions.create(
+                 model="gpt-4o-mini",
+                 messages=[
+                     {"role": "system", "content": "You are a query refinement expert for loan application rules."},
+                     {"role": "user", "content": refine_prompt}
+                 ],
+                 temperature=0.3,
+                 max_tokens=100
+             )
+
+             refined = response.choices[0].message.content.strip().strip('"')
+             logger.info(f"Refined query: {refined}")
+             return refined if refined else original_query
+
+         except Exception as e:
+             logger.error(f"Query refinement failed: {str(e)}")
+             return original_query
+
+     def retrieve_with_crag(self, query, top_k=2, relevance_threshold=0.7):
+         """
+         CRAG (corrective RAG): retrieve docs, judge their relevance,
+         and retry with a refined query if the results look weak.
+         """
+         try:
+             logger.info(f"CRAG: Retrieving for query: '{query[:50]}...'")
+
+             docs = self.retrieve(query, top_k=top_k)
+
+             if not docs:
+                 logger.warning("No documents retrieved")
+                 return [], 0.0
+
+             # judge how relevant the results are
+             relevance_scores = self.judge_relevance(query, docs)
+
+             for doc, score in zip(docs, relevance_scores):
+                 doc['relevance'] = score
+
+             avg_relevance = float(np.mean(relevance_scores))
+             logger.debug(f"CRAG: Initial relevance: {avg_relevance:.3f}")
+
+             # if relevance is low, refine the query and try again
+             if avg_relevance < relevance_threshold:
+                 logger.info("CRAG: Low relevance detected, refining query...")
+
+                 refined_query = self.refine_query(query, [d['text'] for d in docs])
+                 logger.debug(f"CRAG: Refined query: '{refined_query[:50]}...'")
+
+                 # retry with the refined query
+                 refined_docs = self.retrieve(refined_query, top_k=top_k)
+
+                 if refined_docs:
+                     refined_relevance = self.judge_relevance(refined_query, refined_docs)
+
+                     for doc, score in zip(refined_docs, refined_relevance):
+                         doc['relevance'] = score
+
+                     refined_avg = float(np.mean(refined_relevance))
+                     logger.debug(f"CRAG: Refined relevance: {refined_avg:.3f}")
+
+                     # use the refined results only if they score better
+                     if refined_avg > avg_relevance:
+                         docs = refined_docs
+                         avg_relevance = refined_avg
+                         logger.info("CRAG: Using refined results")
+                     else:
+                         logger.info("CRAG: Keeping original results")
+
+             # sort by relevance score
+             docs.sort(key=lambda x: x['relevance'], reverse=True)
+
+             return docs, avg_relevance
+
+         except Exception as e:
+             logger.error(f"CRAG failed: {str(e)}")
+             return [], 0.0
+
+     def format_context(self, documents, max_length=500):
+         """Format docs into a context string for the LLM."""
+         try:
+             if not documents:
+                 return "No relevant policies found."
+
+             context_parts = []
+             for i, doc in enumerate(documents):
+                 text = doc['text'][:max_length]
+                 relevance = doc.get('relevance', doc.get('score', 0))
+                 context_parts.append(
+                     f"Policy {i+1} (relevance: {relevance:.2f}):\n{text}"
+                 )
+
+             return "\n\n".join(context_parts)
+         except Exception as e:
+             logger.error(f"Context formatting failed: {str(e)}")
+             return "Error formatting policy context."
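
One design note worth illustrating: IndexFlatIP does inner-product search, which equals cosine-similarity search only because the embeddings are L2-normalized. A minimal standalone sketch with made-up random vectors (assumes faiss and numpy are installed):

    import faiss
    import numpy as np

    rng = np.random.default_rng(0)
    docs = rng.normal(size=(4, 8)).astype('float32')
    docs /= np.linalg.norm(docs, axis=1, keepdims=True)  # unit-norm rows

    index = faiss.IndexFlatIP(8)  # inner-product index
    index.add(docs)

    query = (docs[2] + 0.05 * rng.normal(size=8)).astype('float32')
    query /= np.linalg.norm(query)

    scores, ids = index.search(query.reshape(1, -1), 2)
    # with unit-norm vectors, scores are cosine similarities; doc 2 should rank first
    print(ids[0], scores[0])

Skipping the normalization would silently turn this into a dot-product ranking biased toward long vectors.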
app/services/rule_service.py ADDED
@@ -0,0 +1,315 @@
+ import json
+ import os
+ from openai import OpenAI
+ from json_logic import jsonLogic
+ from loguru import logger
+ from app.constants import MOCK_STORE_SAMPLES
+ from app.models import KeyMapping
+
+
+ class RuleGenerationService:
+     """Generates JSON Logic rules with gpt-4o-mini,
+     using self-consistency voting to pick the best rule from multiple attempts."""
+
+     def __init__(self):
+         try:
+             self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+             self.model = "gpt-4o-mini"
+             logger.success("RuleGenerationService initialized")
+         except Exception as e:
+             logger.error(f"Failed to initialize RuleGenerationService: {str(e)}")
+             raise
+
+     def build_system_prompt(self, available_keys, policy_context):
+         """Build the system prompt listing the allowed keys, policy context,
+         operators, and few-shot examples."""
+         try:
+             keys_str = json.dumps(available_keys, indent=2)
+
+             system_prompt = f"""You are an expert JSON Logic rule generator for loan application systems.
+
+ AVAILABLE KEYS (use ONLY these in {{"var": "key"}}):
+ {keys_str}
+
+ POLICY CONTEXT:
+ {policy_context}
+
+ JSON LOGIC OPERATORS:
+ - Logical: "and", "or", "!", "if"
+ - Comparison: ">", "<", ">=", "<=", "==", "!="
+ - Arrays: "in", "some", "all"
+ - Math: "+", "-", "*", "/"
+
+ RULES:
+ 1. Use ONLY the available keys listed above
+ 2. All keys must be referenced using {{"var": "key.name"}}
+ 3. Generate valid JSON Logic syntax
+ 4. Be precise with thresholds from policies
+ 5. Use "and" for multiple conditions, "or" for alternatives
+
+ OUTPUT FORMAT (must be valid JSON):
+ {{
+   "json_logic": {{"and": [...]}},
+   "explanation": "Brief 1-2 sentence explanation",
+   "used_keys": ["key1", "key2"],
+   "confidence": 0.0-1.0
+ }}
+
+ EXAMPLES:
+
+ User: "Approve if bureau score > 700"
+ Output:
+ {{
+   "json_logic": {{">": [{{"var": "bureau.score"}}, 700]}},
+   "explanation": "Approves applications where bureau score exceeds 700.",
+   "used_keys": ["bureau.score"],
+   "confidence": 0.95
+ }}
+
+ User: "Reject if wilful default OR suit filed"
+ Output:
+ {{
+   "json_logic": {{"or": [
+     {{"==": [{{"var": "bureau.wilful_default"}}, true]}},
+     {{"==": [{{"var": "bureau.suit_filed"}}, true]}}
+   ]}},
+   "explanation": "Rejects applications with wilful default or suit filed status.",
+   "used_keys": ["bureau.wilful_default", "bureau.suit_filed"],
+   "confidence": 0.92
+ }}
+
+ User: "Approve if age between 25 and 60"
+ Output:
+ {{
+   "json_logic": {{"and": [
+     {{">=": [{{"var": "primary_applicant.age"}}, 25]}},
+     {{"<=": [{{"var": "primary_applicant.age"}}, 60]}}
+   ]}},
+   "explanation": "Approves when primary applicant age is between 25 and 60 years inclusive.",
+   "used_keys": ["primary_applicant.age"],
+   "confidence": 0.90
+ }}"""
+
+             return system_prompt
+         except Exception as e:
+             logger.error(f"Failed to build system prompt: {str(e)}")
+             return ""
+
+     def generate_single_rule(self, prompt, system_prompt, temperature=0.2):
+         """Call the LLM once to get a rule.
+         Temperature controls randomness; lower is more deterministic."""
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": prompt}
+                 ],
+                 temperature=temperature,
+                 max_tokens=800,
+                 response_format={"type": "json_object"}
+             )
+
+             content = response.choices[0].message.content
+             result = json.loads(content)
+
+             # check that all required fields are present
+             if not all(k in result for k in ['json_logic', 'explanation', 'used_keys']):
+                 logger.warning("LLM response missing required fields")
+                 raise ValueError("Missing required fields in LLM response")
+
+             # add a default confidence if the LLM omitted it
+             if 'confidence' not in result:
+                 result['confidence'] = 0.8
+
+             return result
+
+         except Exception as e:
+             logger.error(f"Rule generation failed: {str(e)}")
+             return None
+
+     def validate_rule(self, rule, available_keys):
+         """Check that the rule only uses allowed keys by extracting
+         every "var" reference and validating it."""
+         try:
+             if not rule:
+                 return False, "Empty rule"
+
+             # recursively find all var references
+             def extract_vars(obj):
+                 vars_found = []
+                 if isinstance(obj, dict):
+                     if "var" in obj:
+                         vars_found.append(obj["var"])
+                     for value in obj.values():
+                         vars_found.extend(extract_vars(value))
+                 elif isinstance(obj, list):
+                     for item in obj:
+                         vars_found.extend(extract_vars(item))
+                 return vars_found
+
+             used_vars = extract_vars(rule)
+
+             # every var must be in the allowed list
+             invalid_vars = [v for v in used_vars if v not in available_keys]
+             if invalid_vars:
+                 logger.warning(f"Rule uses invalid keys: {invalid_vars}")
+                 return False, f"Invalid keys used: {invalid_vars}"
+
+             return True, ""
+
+         except Exception as e:
+             logger.error(f"Rule validation failed: {str(e)}")
+             return False, str(e)
+
+     def test_rule_on_mocks(self, rule, num_samples=5):
+         """Run the rule against mock data to check that it executes;
+         this tests robustness, not correctness."""
+         try:
+             if not rule:
+                 return 0.0
+
+             successes = 0
+             samples = MOCK_STORE_SAMPLES[:num_samples]
+
+             for sample in samples:
+                 try:
+                     # apply the JSON Logic rule to the sample
+                     jsonLogic(rule, sample)
+                     successes += 1
+                 except Exception as e:
+                     # rule failed to execute on this sample
+                     logger.debug(f"Rule test failed on sample: {str(e)}")
+                     continue
+
+             return successes / len(samples) if samples else 0.0
+
+         except Exception as e:
+             logger.error(f"Mock testing failed: {str(e)}")
+             return 0.0
+
+     def self_consistency_vote(self, variants):
+         """Pick the best rule from multiple variants,
+         scoring on confidence, validation rate, and simplicity."""
+         try:
+             if not variants:
+                 return None
+
+             if len(variants) == 1:
+                 return variants[0]
+
+             scored_variants = []
+             for variant in variants:
+                 score = 0.0
+
+                 # the LLM's own confidence (40%)
+                 score += variant.get('confidence', 0.5) * 0.4
+
+                 # how well it ran on mock data (40%)
+                 validation_rate = variant.get('validation_rate', 0.0)
+                 score += validation_rate * 0.4
+
+                 # prefer simpler rules: shorter serialized JSON, worth up to 20%
+                 rule_str = json.dumps(variant['json_logic'])
+                 complexity_penalty = min(len(rule_str) / 500, 0.2)
+                 score += (0.2 - complexity_penalty)
+
+                 scored_variants.append((score, variant))
+
+             # sort by score, descending
+             scored_variants.sort(key=lambda x: x[0], reverse=True)
+
+             scores_str = [f'{s:.3f}' for s, _ in scored_variants]
+             logger.debug(f"Self-consistency scores: {scores_str}")
+
+             return scored_variants[0][1]  # the best-scoring variant
+
+         except Exception as e:
+             logger.error(f"Self-consistency voting failed: {str(e)}")
+             return variants[0] if variants else None
+
+     def generate_rule(self, prompt, key_mappings, policy_context, num_variants=3):
+         """
+         Main entry point: generate a rule with self-consistency.
+         Tries several temperatures and picks the best variant.
+         """
+         try:
+             logger.info(f"Generating {num_variants} rule variants...")
+
+             # the list of allowed keys
+             available_keys = [m.mapped_to for m in key_mappings]
+
+             # build the full system prompt
+             system_prompt = self.build_system_prompt(available_keys, policy_context)
+
+             # generate multiple variants (capped at three temperatures)
+             variants = []
+             temperatures = [0.1, 0.3, 0.5][:num_variants]
+
+             for i, temp in enumerate(temperatures):
+                 logger.debug(f"Generating variant {i+1} with temp={temp}...")
+
+                 result = self.generate_single_rule(prompt, system_prompt, temperature=temp)
+
+                 if result:
+                     # validate that it uses only allowed keys
+                     is_valid, error_msg = self.validate_rule(
+                         result['json_logic'],
+                         available_keys
+                     )
+
+                     if not is_valid:
+                         logger.warning(f"Variant {i+1} validation failed: {error_msg}")
+                         continue
+
+                     # test on mock data
+                     validation_rate = self.test_rule_on_mocks(result['json_logic'])
+                     result['validation_rate'] = validation_rate
+
+                     logger.info(f"Variant {i+1}: conf={result['confidence']:.3f}, val={validation_rate:.3f}")
+
+                     variants.append(result)
+
+             if not variants:
+                 logger.error("Failed to generate any valid rule variants")
+                 raise ValueError("Failed to generate any valid rule variants")
+
+             # vote for the best rule
+             best_rule = self.self_consistency_vote(variants)
+
+             logger.success("Selected best rule")
+
+             return best_rule
+
+         except Exception as e:
+             logger.error(f"Rule generation failed: {str(e)}")
+             raise
+
+     def calculate_confidence_score(self, rule_result, key_mappings, policy_relevance):
+         """Calculate overall confidence as a weighted average of several factors."""
+         try:
+             # how well the keys matched (40%)
+             avg_key_sim = sum(m.similarity for m in key_mappings) / len(key_mappings) if key_mappings else 0.0
+
+             # how relevant the policies were (30%)
+             policy_score = policy_relevance
+
+             # LLM confidence plus mock validation rate (30%)
+             llm_confidence = rule_result.get('confidence', 0.8)
+             validation_rate = rule_result.get('validation_rate', 0.8)
+             generation_score = (llm_confidence + validation_rate) / 2
+
+             # weighted average
+             confidence = (
+                 avg_key_sim * 0.4 +
+                 policy_score * 0.3 +
+                 generation_score * 0.3
+             )
+
+             # clamp to 0-1
+             return float(min(max(confidence, 0.0), 1.0))
+
+         except Exception as e:
+             logger.error(f"Confidence calculation failed: {str(e)}")
+             return 0.5  # default fallback
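
For intuition on the 40/30/30 weighting in calculate_confidence_score, here is the arithmetic with made-up component values:

    # invented inputs for illustration
    avg_key_sim = 0.85       # key-mapping similarity (weight 0.4)
    policy_score = 0.70      # CRAG policy relevance (weight 0.3)
    llm_confidence = 0.90
    validation_rate = 1.00
    generation_score = (llm_confidence + validation_rate) / 2  # 0.95 (weight 0.3)

    confidence = avg_key_sim * 0.4 + policy_score * 0.3 + generation_score * 0.3
    print(round(confidence, 3))  # 0.34 + 0.21 + 0.285 = 0.835

Key mapping gets the largest weight because a rule built on the wrong store keys is wrong regardless of how cleanly it executes.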