Upload 5 files
Browse files- app/services/embedding_service.py +102 -0
- app/services/key_mapper.py +222 -0
- app/services/rag_service.py +270 -0
- app/services/rule_service.py +315 -0
app/services/embedding_service.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import SentenceTransformer
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from loguru import logger
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EmbeddingService:
    """Wraps a SentenceTransformer model for text embeddings.

    All embeddings are L2-normalized float32 vectors. Failures are
    logged (loguru) and degrade to zero vectors instead of raising, so
    callers can keep going.
    """

    def __init__(self, model_name=None):
        """Load the embedding model.

        Args:
            model_name: explicit model name; falls back to the
                EMBED_MODEL env var, then "all-MiniLM-L6-v2".

        Raises:
            Exception: re-raised if the model cannot be loaded — the
                service is unusable without one.
        """
        try:
            self.model_name = model_name or os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2")
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            self.dimension = self.model.get_sentence_embedding_dimension()
            logger.success(f"Model loaded. Embedding dimension: {self.dimension}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {str(e)}")
            raise

    def embed_single(self, text):
        """Embed one string; blank input or any error yields a zero vector."""
        try:
            if not text or not text.strip():
                return np.zeros(self.dimension, dtype=np.float32)

            embedding = self.model.encode(
                text,
                normalize_embeddings=True,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            return embedding.astype(np.float32)
        except Exception as e:
            logger.error(f"Single embedding failed: {str(e)}")
            return np.zeros(self.dimension, dtype=np.float32)

    def embed_batch(self, texts, batch_size=32):
        """Embed a list of strings, preserving input positions.

        Blank/empty entries get zero vectors at their original indices.
        Always returns an array of shape (len(texts), dimension) — even
        for an empty input list — so 2-D consumers (np.vstack, FAISS
        index.add) can rely on the shape.
        """
        try:
            if not texts:
                # BUGFIX: was `np.array([], dtype=np.float32)` — a 1-D,
                # shape-(0,) array, inconsistent with every other return
                # path of this method. Keep the output 2-D.
                return np.zeros((0, self.dimension), dtype=np.float32)

            valid_texts = []
            valid_indices = []
            for i, text in enumerate(texts):
                if text and text.strip():
                    valid_texts.append(text)
                    valid_indices.append(i)

            if not valid_texts:
                return np.zeros((len(texts), self.dimension), dtype=np.float32)

            # one batched encode call for all non-blank texts
            embeddings = self.model.encode(
                valid_texts,
                batch_size=batch_size,
                normalize_embeddings=True,
                show_progress_bar=False,
                convert_to_numpy=True
            )

            # scatter embeddings back to their original positions
            output = np.zeros((len(texts), self.dimension), dtype=np.float32)
            for i, valid_idx in enumerate(valid_indices):
                output[valid_idx] = embeddings[i]

            return output.astype(np.float32)
        except Exception as e:
            logger.error(f"Batch embedding failed: {str(e)}")
            return np.zeros((len(texts), self.dimension), dtype=np.float32)

    def cosine_similarity(self, vec1, vec2):
        """Cosine similarity of two vectors, clipped to [-1, 1].

        Returns 0.0 when either vector has zero norm or on error.
        """
        try:
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)

            if norm1 == 0 or norm2 == 0:
                return 0.0

            similarity = np.dot(vec1, vec2) / (norm1 * norm2)
            return float(np.clip(similarity, -1.0, 1.0))
        except Exception as e:
            logger.error(f"Cosine similarity failed: {str(e)}")
            return 0.0

    def batch_cosine_similarity(self, query_vec, corpus_vecs):
        """Similarities of one query vector against many corpus rows.

        Assumes corpus rows are already normalized (as produced by this
        service); a single matrix product is faster than a Python loop.
        """
        try:
            # epsilon guards against a zero-norm query vector
            query_norm = query_vec / (np.linalg.norm(query_vec) + 1e-8)

            # dot product == cosine similarity for normalized vectors
            similarities = np.dot(corpus_vecs, query_norm)
            return np.clip(similarities, -1.0, 1.0)
        except Exception as e:
            logger.error(f"Batch cosine similarity failed: {str(e)}")
            return np.zeros(len(corpus_vecs))

    def get_dimension(self):
        """Get embedding dimension"""
        return self.dimension
|
app/services/key_mapper.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from rank_bm25 import BM25Okapi
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
from loguru import logger
|
| 6 |
+
from app.constants import SAMPLE_STORE_KEYS, build_key_search_text
|
| 7 |
+
from app.models import KeyMapping
|
| 8 |
+
from app.services.embedding_service import EmbeddingService
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class KeyMapper:
    """Maps free-text user phrases to concrete store keys.

    Hybrid retrieval: dense (sentence embeddings) and sparse (BM25)
    rankings are fused with Reciprocal Rank Fusion (RRF).
    TODO: maybe add cross-encoder reranking later if needed.
    """

    def __init__(self, embedding_service):
        """Precompute key embeddings and a BM25 index over SAMPLE_STORE_KEYS.

        Args:
            embedding_service: EmbeddingService used for all dense work.

        Env vars: RRF_K (fusion constant, default 60) and SIM_THRESHOLD
        (default similarity cutoff, default 0.7).

        Raises:
            Exception: re-raised when any precomputation step fails.
        """
        try:
            self.embed_service = embedding_service
            self.rrf_k = int(os.getenv("RRF_K", "60"))  # k=60 worked best in testing
            self.threshold = float(os.getenv("SIM_THRESHOLD", "0.7"))

            logger.info("Initializing KeyMapper...")

            self.keys = SAMPLE_STORE_KEYS

            # build text for each key to search against
            self.key_texts = [build_key_search_text(k) for k in self.keys]
            logger.debug(f"Built {len(self.key_texts)} key search texts")

            # precompute embeddings so we dont have to do it every time
            logger.info("Computing key embeddings...")
            self.key_embeddings = self.embed_service.embed_batch(self.key_texts)
            logger.debug(f"Key embeddings shape: {self.key_embeddings.shape}")

            # setup BM25 for keyword matching
            logger.info("Building BM25 index...")
            self.tokenized_keys = [self.tokenize(text) for text in self.key_texts]
            self.bm25 = BM25Okapi(self.tokenized_keys)
            logger.success("KeyMapper initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize KeyMapper: {str(e)}")
            raise

    def tokenize(self, text):
        """Lowercased word-character tokens of `text`; [] on error."""
        try:
            tokens = re.findall(r'\w+', text.lower())
            return tokens
        except Exception as e:
            logger.error(f"Tokenization failed: {str(e)}")
            return []

    def extract_key_phrases(self, prompt):
        """Candidate phrases from the prompt.

        Order: full prompt, bigrams, trigrams, then tokens longer than 3
        chars; de-duplicated preserving order and capped at 15. Falls
        back to [prompt] on error.
        """
        try:
            phrases = []

            phrases.append(prompt.strip())

            tokens = self.tokenize(prompt)

            # bigrams - pairs of words
            for i in range(len(tokens) - 1):
                phrases.append(f"{tokens[i]} {tokens[i+1]}")

            # trigrams - three word combos
            for i in range(len(tokens) - 2):
                phrases.append(f"{tokens[i]} {tokens[i+1]} {tokens[i+2]}")

            # add longer tokens only (skip short words like 'is', 'or')
            phrases.extend([t for t in tokens if len(t) > 3])

            # remove dupes but keep order
            seen = set()
            unique = []
            for p in phrases:
                if p not in seen:
                    seen.add(p)
                    unique.append(p)

            return unique[:15]  # limit to avoid too many
        except Exception as e:
            logger.error(f"Phrase extraction failed: {str(e)}")
            return [prompt]  # fallback to just the prompt

    def compute_dense_ranks(self, prompt):
        """Semantic ranking of all keys against the prompt.

        Returns (1-based rank positions per key, raw cosine similarities);
        neutral defaults (1..n ranks, zero sims) on error.
        """
        try:
            prompt_emb = self.embed_service.embed_single(prompt)

            similarities = self.embed_service.batch_cosine_similarity(
                prompt_emb,
                self.key_embeddings
            )

            # sort by similarity
            ranks = np.argsort(-similarities)

            # convert to rank positions starting from 1
            rank_positions = np.zeros(len(self.keys), dtype=int)
            for pos, idx in enumerate(ranks):
                rank_positions[idx] = pos + 1

            return rank_positions, similarities
        except Exception as e:
            logger.error(f"Dense ranking failed: {str(e)}")
            # return default ranks if something breaks
            default_ranks = np.arange(1, len(self.keys) + 1)
            default_sims = np.zeros(len(self.keys))
            return default_ranks, default_sims

    def compute_sparse_ranks(self, prompt):
        """Keyword (BM25) ranking of all keys against the prompt.

        Returns (1-based rank positions per key, raw BM25 scores);
        neutral defaults on error.
        """
        try:
            prompt_tokens = self.tokenize(prompt)
            bm25_scores = self.bm25.get_scores(prompt_tokens)

            ranks = np.argsort(-bm25_scores)

            rank_positions = np.zeros(len(self.keys), dtype=int)
            for pos, idx in enumerate(ranks):
                rank_positions[idx] = pos + 1

            return rank_positions, bm25_scores
        except Exception as e:
            logger.error(f"Sparse ranking failed: {str(e)}")
            default_ranks = np.arange(1, len(self.keys) + 1)
            default_scores = np.zeros(len(self.keys))
            return default_ranks, default_scores

    def apply_rrf(self, dense_ranks, sparse_ranks):
        """Reciprocal Rank Fusion: 1/(k+dense) + 1/(k+sparse) per key.

        Works better than a weighted average of raw scores; falls back
        to dense-only fusion on error.
        """
        try:
            rrf_scores = (1.0 / (self.rrf_k + dense_ranks)) + \
                         (1.0 / (self.rrf_k + sparse_ranks))
            return rrf_scores
        except Exception as e:
            logger.error(f"RRF fusion failed: {str(e)}")
            # fallback to just dense ranks
            return 1.0 / (self.rrf_k + dense_ranks)

    def map_keys(self, prompt, top_k=5):
        """Map user prompt to actual store keys.

        Returns up to top_k KeyMapping objects ordered by fused RRF
        score, each tagged with the prompt phrase that best matches the
        key; [] on error.
        """
        try:
            logger.info(f"Mapping keys for prompt: {prompt[:50]}...")

            # get rankings from both methods
            dense_ranks, dense_sims = self.compute_dense_ranks(prompt)
            sparse_ranks, sparse_scores = self.compute_sparse_ranks(prompt)

            # combine them using RRF
            rrf_scores = self.apply_rrf(dense_ranks, sparse_ranks)

            # sort by combined score
            sorted_indices = np.argsort(-rrf_scores)

            # extract phrases from prompt
            key_phrases = self.extract_key_phrases(prompt)

            # PERF FIX: embed each candidate phrase exactly once, up front.
            # Previously every phrase was re-embedded inside the per-key
            # loop — O(keys * phrases) model calls for identical results.
            phrase_embeddings = [
                (phrase, self.embed_service.embed_single(phrase))
                for phrase in key_phrases
            ]

            # PERF FIX: hoisted out of the loop — the best possible RRF
            # score (rank 1 in both lists) is loop-invariant.
            max_rrf = 2.0 / (self.rrf_k + 1)

            # build the mappings
            mappings = []
            for idx in sorted_indices:
                # normalize score to 0-1 range
                normalized_score = float(rrf_scores[idx] / max_rrf)

                # find which phrase matches this key best
                key_emb = self.key_embeddings[idx]
                best_phrase = prompt  # default to full prompt
                best_phrase_sim = dense_sims[idx]

                # check each precomputed phrase embedding
                for phrase, phrase_emb in phrase_embeddings:
                    phrase_sim = self.embed_service.cosine_similarity(phrase_emb, key_emb)
                    if phrase_sim > best_phrase_sim:
                        best_phrase = phrase
                        best_phrase_sim = phrase_sim

                mappings.append(KeyMapping(
                    user_phrase=best_phrase[:50],
                    mapped_to=self.keys[idx]['value'],
                    similarity=float(np.clip(normalized_score, 0.0, 1.0))
                ))

                if len(mappings) >= top_k:
                    break

            logger.success(f"Mapped {len(mappings)} keys successfully")
            return mappings

        except Exception as e:
            logger.error(f"Key mapping failed: {str(e)}")
            # return empty list if everything breaks
            return []

    def get_top_keys(self, prompt, top_k=5, min_similarity=None):
        """Get top keys with full metadata.

        Filters map_keys output by `min_similarity` (or the instance
        threshold) and joins each hit back to its full key dict; [] on
        error.
        """
        try:
            threshold = min_similarity if min_similarity is not None else self.threshold

            # get more than needed then filter
            mappings = self.map_keys(prompt, top_k=top_k * 2)

            filtered = [m for m in mappings if m.similarity >= threshold]

            # add full key details
            result = []
            for mapping in filtered[:top_k]:
                key_obj = next((k for k in self.keys if k['value'] == mapping.mapped_to), None)
                if key_obj:
                    result.append({
                        **key_obj,
                        'similarity': mapping.similarity,
                        'matched_phrase': mapping.user_phrase
                    })

            return result
        except Exception as e:
            logger.error(f"get_top_keys failed: {str(e)}")
            return []
|
app/services/rag_service.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import faiss
|
| 3 |
+
import os
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from loguru import logger
|
| 6 |
+
from app.services.embedding_service import EmbeddingService
|
| 7 |
+
from app.constants import POLICIES
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RAGService:
    """Policy retrieval over a FAISS index with a corrective-RAG loop.

    CRAG = corrective RAG: retrieve, have an LLM judge relevance, and
    retry with a refined query when the first pass looks weak.
    """

    def __init__(self, embedding_service):
        """Embed the static POLICIES and build the FAISS index.

        Args:
            embedding_service: EmbeddingService used for all embeddings.

        Raises:
            Exception: re-raised when initialization fails.
        """
        try:
            self.embed_service = embedding_service
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

            logger.info("Initializing RAG Service...")

            # copy so add_documents never mutates the shared constant
            self.policies = POLICIES.copy()
            self.build_index()

            logger.success(f"RAG Service initialized with {len(self.policies)} policies")
        except Exception as e:
            logger.error(f"Failed to initialize RAG Service: {str(e)}")
            raise

    def build_index(self):
        """Build FAISS index from policy docs.

        On failure — or with no policies — leaves index and embeddings
        as None; retrieve() treats that as an empty index.
        """
        try:
            if not self.policies:
                logger.warning("No policies to index")
                self.index = None
                self.policy_embeddings = None
                return

            logger.info(f"Embedding {len(self.policies)} policy documents...")

            # embed all policies
            self.policy_embeddings = self.embed_service.embed_batch(self.policies)

            # inner-product index == cosine similarity, since the
            # embedding service normalizes its vectors
            dimension = self.embed_service.get_dimension()
            self.index = faiss.IndexFlatIP(dimension)

            self.index.add(self.policy_embeddings.astype('float32'))

            logger.debug(f"FAISS index built with {self.index.ntotal} vectors")
        except Exception as e:
            logger.error(f"Index building failed: {str(e)}")
            self.index = None
            self.policy_embeddings = None

    def add_documents(self, new_docs):
        """Add new docs to the in-memory index on the fly (not persisted)."""
        try:
            if not new_docs:
                return

            # BUGFIX: guard a never-built index — previously this fell
            # into the except handler via self.index.add on None and then
            # np.vstack with a None policy_embeddings.
            if self.index is None:
                logger.warning("Cannot add documents: index was never built")
                return

            logger.info(f"Adding {len(new_docs)} temporary documents...")

            new_embeddings = self.embed_service.embed_batch(new_docs)
            self.index.add(new_embeddings.astype('float32'))
            self.policies.extend(new_docs)

            # stack new embeddings with old ones
            self.policy_embeddings = np.vstack([self.policy_embeddings, new_embeddings])

            logger.debug(f"Index now contains {self.index.ntotal} documents")
        except Exception as e:
            logger.error(f"Failed to add documents: {str(e)}")

    def retrieve(self, query, top_k=3):
        """Top-k nearest policies for a query.

        Returns dicts with text/score/index/rank; [] when the index is
        empty or on error.
        """
        try:
            if self.index is None or self.index.ntotal == 0:
                logger.warning("Index is empty, returning no results")
                return []

            # embed and search
            query_emb = self.embed_service.embed_single(query).reshape(1, -1)
            scores, indices = self.index.search(query_emb.astype('float32'), top_k)

            results = []
            for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                # BUGFIX: FAISS pads the result with idx == -1 when fewer
                # than top_k vectors exist; the old `idx < len(...)` test
                # let -1 through and silently returned the LAST policy.
                if 0 <= idx < len(self.policies):
                    results.append({
                        'text': self.policies[idx],
                        'score': float(score),
                        'index': int(idx),
                        'rank': i + 1
                    })

            logger.debug(f"Retrieved {len(results)} documents")
            return results
        except Exception as e:
            logger.error(f"Retrieval failed: {str(e)}")
            return []

    def judge_relevance(self, query, documents):
        """LLM-score each document's relevance to the query in [0, 1].

        Falls back to squashed retrieval scores (score / (score + 1))
        when the LLM call or JSON parse fails; pads with 0.5 if the LLM
        returns too few scores.
        """
        try:
            import json  # local import; only needed on this path

            if not documents:
                return []

            doc_texts = "\n\n".join([
                f"DOCUMENT {i+1}:\n{doc['text']}"
                for i, doc in enumerate(documents)
            ])

            judge_prompt = f"""You are an expert relevance evaluator for a loan application rule generation system.

QUERY: {query}

RETRIEVED DOCUMENTS:
{doc_texts}

Task: Rate the relevance of each document to the query on a scale of 0.0 to 1.0.
- 1.0 = Highly relevant, directly helps answer the query
- 0.5 = Somewhat relevant, provides context
- 0.0 = Not relevant at all

Respond ONLY with a JSON array of scores, one per document in order.
Example: [0.9, 0.6, 0.2]

Scores:"""

            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a relevance scoring expert. Respond only with a JSON array of numbers."},
                    {"role": "user", "content": judge_prompt}
                ],
                temperature=0.1,
                max_tokens=100
            )

            content = response.choices[0].message.content.strip()
            scores = json.loads(content)

            # clamp to 0-1 range
            scores = [max(0.0, min(1.0, float(s))) for s in scores]

            # pad if llm didnt return enough scores
            while len(scores) < len(documents):
                scores.append(0.5)

            logger.debug(f"LLM judge scores: {scores}")
            return scores[:len(documents)]

        except Exception as e:
            logger.error(f"LLM judge failed: {str(e)}")
            # fallback - just use retrieval scores
            return [doc['score'] / (doc['score'] + 1.0) for doc in documents]

    def refine_query(self, original_query, low_relevance_docs):
        """Ask the LLM for a better search query when results were weak.

        `low_relevance_docs` is currently unused by the prompt but kept
        for interface stability. Returns the original query on error or
        an empty LLM reply.
        """
        try:
            refine_prompt = f"""Original query: "{original_query}"

The retrieved documents were not very relevant. Suggest a better search query that focuses on key loan application terms like:
- Bureau score, credit score, CIBIL
- Business vintage, age
- Overdue amounts, DPD
- Income, FOIR
- GST, banking metrics

Respond with ONLY the improved query, no explanation.

Improved query:"""

            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a query refinement expert for loan application rules."},
                    {"role": "user", "content": refine_prompt}
                ],
                temperature=0.3,
                max_tokens=100
            )

            refined = response.choices[0].message.content.strip().strip('"')
            logger.info(f"Refined query: {refined}")
            return refined if refined else original_query

        except Exception as e:
            logger.error(f"Query refinement failed: {str(e)}")
            return original_query

    def retrieve_with_crag(self, query, top_k=2, relevance_threshold=0.7):
        """
        CRAG = Corrective RAG.

        Retrieves docs and LLM-judges them; if mean relevance is below
        `relevance_threshold`, refines the query and retries once,
        keeping whichever result set judged better.

        Returns:
            (docs sorted by relevance desc, mean relevance) — or
            ([], 0.0) when nothing is retrieved or on error.
        """
        try:
            logger.info(f"CRAG: Retrieving for query: '{query[:50]}...'")

            docs = self.retrieve(query, top_k=top_k)

            if not docs:
                logger.warning("No documents retrieved")
                return [], 0.0

            # judge how relevant results are
            relevance_scores = self.judge_relevance(query, docs)

            for doc, score in zip(docs, relevance_scores):
                doc['relevance'] = score

            avg_relevance = np.mean(relevance_scores)
            logger.debug(f"CRAG: Initial relevance: {avg_relevance:.3f}")

            # if relevance sucks, refine and try again
            if avg_relevance < relevance_threshold:
                logger.info("CRAG: Low relevance detected, refining query...")

                refined_query = self.refine_query(query, [d['text'] for d in docs])
                logger.debug(f"CRAG: Refined query: '{refined_query[:50]}...'")

                # try again with better query
                refined_docs = self.retrieve(refined_query, top_k=top_k)

                if refined_docs:
                    refined_relevance = self.judge_relevance(refined_query, refined_docs)

                    for doc, score in zip(refined_docs, refined_relevance):
                        doc['relevance'] = score

                    refined_avg = np.mean(refined_relevance)
                    logger.debug(f"CRAG: Refined relevance: {refined_avg:.3f}")

                    # use refined results only if theyre better
                    if refined_avg > avg_relevance:
                        docs = refined_docs
                        avg_relevance = refined_avg
                        logger.info("CRAG: Using refined results")
                    else:
                        logger.info("CRAG: Keeping original results")

            # sort by relevance score
            docs.sort(key=lambda x: x['relevance'], reverse=True)

            return docs, avg_relevance

        except Exception as e:
            logger.error(f"CRAG failed: {str(e)}")
            return [], 0.0

    def format_context(self, documents, max_length=500):
        """Format docs into a numbered string for LLM context.

        Each doc's text is truncated to `max_length` chars; relevance
        falls back to the retrieval score, then 0.
        """
        try:
            if not documents:
                return "No relevant policies found."

            context_parts = []
            for i, doc in enumerate(documents):
                text = doc['text'][:max_length]
                relevance = doc.get('relevance', doc.get('score', 0))
                context_parts.append(
                    f"Policy {i+1} (relevance: {relevance:.2f}):\n{text}"
                )

            return "\n\n".join(context_parts)
        except Exception as e:
            logger.error(f"Context formatting failed: {str(e)}")
            return "Error formatting policy context."
|
app/services/rule_service.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from openai import OpenAI
|
| 4 |
+
from json_logic import jsonLogic
|
| 5 |
+
from loguru import logger
|
| 6 |
+
from app.constants import MOCK_STORE_SAMPLES
|
| 7 |
+
from app.models import KeyMapping
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RuleGenerationService:
|
| 11 |
+
""" Generates json logic rules using gpt-4o-mini
|
| 12 |
+
Uses self-consistency voting to pick best rule from multiple attempts
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self):
|
| 16 |
+
try:
|
| 17 |
+
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 18 |
+
self.model = "gpt-4o-mini"
|
| 19 |
+
logger.success("RuleGenerationService initialized")
|
| 20 |
+
except Exception as e:
|
| 21 |
+
logger.error(f"Failed to initialize RuleGenerationService: {str(e)}")
|
| 22 |
+
raise
|
| 23 |
+
|
| 24 |
+
    def build_system_prompt(self, available_keys, policy_context):
        """Compose the system prompt for JSON Logic rule generation.

        Embeds the allowed keys (pretty-printed JSON) and the retrieved
        policy context into a fixed instruction template with few-shot
        examples. Returns "" when prompt assembly fails (e.g. keys are
        not JSON-serializable).
        """
        try:
            keys_str = json.dumps(available_keys, indent=2)

            # NOTE: literal braces in the template are doubled ({{ }})
            # because this is an f-string.
            system_prompt = f"""You are an expert JSON Logic rule generator for loan application systems.

AVAILABLE KEYS (use ONLY these in {{"var": "key"}}):
{keys_str}

POLICY CONTEXT:
{policy_context}

JSON LOGIC OPERATORS:
- Logical: "and", "or", "!", "if"
- Comparison: ">", "<", ">=", "<=", "==", "!="
- Arrays: "in", "some", "all"
- Math: "+", "-", "*", "/"

RULES:
1. Use ONLY the available keys listed above
2. All keys must be referenced using {{"var": "key.name"}}
3. Generate valid JSON Logic syntax
4. Be precise with thresholds from policies
5. Use "and" for multiple conditions, "or" for alternatives

OUTPUT FORMAT (must be valid JSON):
{{
  "json_logic": {{"and": [...]}},
  "explanation": "Brief 1-2 sentence explanation",
  "used_keys": ["key1", "key2"],
  "confidence": 0.0-1.0
}}

EXAMPLES:

User: "Approve if bureau score > 700"
Output:
{{
  "json_logic": {{">": [{{"var": "bureau.score"}}, 700]}},
  "explanation": "Approves applications where bureau score exceeds 700.",
  "used_keys": ["bureau.score"],
  "confidence": 0.95
}}

User: "Reject if wilful default OR suit filed"
Output:
{{
  "json_logic": {{"or": [
    {{"==": [{{"var": "bureau.wilful_default"}}, true]}},
    {{"==": [{{"var": "bureau.suit_filed"}}, true]}}
  ]}},
  "explanation": "Rejects applications with wilful default or suit filed status.",
  "used_keys": ["bureau.wilful_default", "bureau.suit_filed"],
  "confidence": 0.92
}}

User: "Approve if age between 25 and 60"
Output:
{{
  "json_logic": {{"and": [
    {{">=": [{{"var": "primary_applicant.age"}}, 25]}},
    {{"<=": [{{"var": "primary_applicant.age"}}, 60]}}
  ]}},
  "explanation": "Approves when primary applicant age is between 25 and 60 years inclusive.",
  "used_keys": ["primary_applicant.age"],
  "confidence": 0.90
}}"""

            return system_prompt
        except Exception as e:
            logger.error(f"Failed to build system prompt: {str(e)}")
            return ""
|
| 97 |
+
|
| 98 |
+
def generate_single_rule(self, prompt, system_prompt, temperature=0.2):
    """Ask the LLM for a single JSON Logic rule.

    temperature controls randomness - lower = more deterministic.
    Returns the parsed response dict (with confidence defaulted to 0.8
    when the model omits it), or None when the call, JSON parsing, or
    field validation fails.
    """
    try:
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
            max_tokens=800,
            response_format={"type": "json_object"},
        )

        parsed = json.loads(completion.choices[0].message.content)

        # the model must return all three of these fields; bail out otherwise
        required = ('json_logic', 'explanation', 'used_keys')
        if any(field not in parsed for field in required):
            logger.warning("LLM response missing required fields")
            raise ValueError("Missing required fields in LLM response")

        # fall back to a moderate confidence when the model forgot to include it
        parsed.setdefault('confidence', 0.8)

        return parsed

    except Exception as e:
        logger.error(f"Rule generation failed: {str(e)}")
        return None
def validate_rule(self, rule, available_keys):
|
| 132 |
+
# checks if rule only uses allowed keys
|
| 133 |
+
# extracts all {"var": "..."} and validates them
|
| 134 |
+
try:
|
| 135 |
+
if not rule:
|
| 136 |
+
return False, "Empty rule"
|
| 137 |
+
|
| 138 |
+
# recursively find all var references
|
| 139 |
+
def extract_vars(obj):
|
| 140 |
+
vars_found = []
|
| 141 |
+
if isinstance(obj, dict):
|
| 142 |
+
if "var" in obj:
|
| 143 |
+
vars_found.append(obj["var"])
|
| 144 |
+
for value in obj.values():
|
| 145 |
+
vars_found.extend(extract_vars(value))
|
| 146 |
+
elif isinstance(obj, list):
|
| 147 |
+
for item in obj:
|
| 148 |
+
vars_found.extend(extract_vars(item))
|
| 149 |
+
return vars_found
|
| 150 |
+
|
| 151 |
+
used_vars = extract_vars(rule)
|
| 152 |
+
|
| 153 |
+
# check all vars are in allowed list
|
| 154 |
+
invalid_vars = [v for v in used_vars if v not in available_keys]
|
| 155 |
+
if invalid_vars:
|
| 156 |
+
logger.warning(f"Rule uses invalid keys: {invalid_vars}")
|
| 157 |
+
return False, f"Invalid keys used: {invalid_vars}"
|
| 158 |
+
|
| 159 |
+
return True, ""
|
| 160 |
+
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.error(f"Rule validation failed: {str(e)}")
|
| 163 |
+
return False, str(e)
|
| 164 |
+
|
| 165 |
+
def test_rule_on_mocks(self, rule, num_samples=5):
|
| 166 |
+
# runs the rule against mock data to see if it breaks
|
| 167 |
+
# doesn't check correctness, just that it executes
|
| 168 |
+
try:
|
| 169 |
+
if not rule:
|
| 170 |
+
return 0.0
|
| 171 |
+
|
| 172 |
+
successes = 0
|
| 173 |
+
samples = MOCK_STORE_SAMPLES[:num_samples]
|
| 174 |
+
|
| 175 |
+
for sample in samples:
|
| 176 |
+
try:
|
| 177 |
+
# apply json logic rule
|
| 178 |
+
result = jsonLogic(rule, sample)
|
| 179 |
+
successes += 1
|
| 180 |
+
except Exception as e:
|
| 181 |
+
# rule broke on this sample
|
| 182 |
+
logger.debug(f"Rule test failed on sample: {str(e)}")
|
| 183 |
+
continue
|
| 184 |
+
|
| 185 |
+
success_rate = successes / len(samples) if samples else 0.0
|
| 186 |
+
return success_rate
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
logger.error(f"Mock testing failed: {str(e)}")
|
| 190 |
+
return 0.0
|
| 191 |
+
|
| 192 |
+
def self_consistency_vote(self, variants):
    """Pick the strongest rule among several generated variants.

    Scoring blends the LLM's self-reported confidence (40%), the mock
    validation rate (40%), and a simplicity bonus that shrinks as the
    serialized rule grows, capped at 20%. Falls back to the first
    variant if anything unexpected goes wrong.
    """
    try:
        if not variants:
            return None

        if len(variants) == 1:
            return variants[0]

        def score_of(candidate):
            # llm's own confidence
            total = candidate.get('confidence', 0.5) * 0.4
            # how well it ran on mock data
            total += candidate.get('validation_rate', 0.0) * 0.4
            # simplicity bonus: shorter serialized rules earn more of the 0.2
            serialized = json.dumps(candidate['json_logic'])
            total += 0.2 - min(len(serialized) / 500, 0.2)
            return total

        scored_variants = [(score_of(v), v) for v in variants]
        scored_variants.sort(key=lambda pair: pair[0], reverse=True)

        scores_str = [f'{s:.3f}' for s, _ in scored_variants]
        logger.debug(f"Self-consistency scores: {scores_str}")

        # highest score wins
        return scored_variants[0][1]

    except Exception as e:
        logger.error(f"Self-consistency voting failed: {str(e)}")
        return variants[0] if variants else None
def generate_rule(self, prompt, key_mappings, policy_context, num_variants=3):
    """
    Main method - generates a rule with self-consistency.

    Generates num_variants candidate rules at increasing temperatures,
    validates each against the allowed keys and smoke-tests it on mock
    data, then votes for the best survivor.

    Args:
        prompt: natural-language rule description from the user.
        key_mappings: mapping objects exposing .mapped_to (canonical key names).
        policy_context: retrieved policy text passed to the system prompt.
        num_variants: number of candidates to generate. Previously this
            was silently capped at 3; any positive count is now honored.

    Raises:
        ValueError: when no variant passes validation.
    """
    try:
        logger.info(f"Generating {num_variants} rule variants...")

        # get list of allowed keys
        available_keys = [m.mapped_to for m in key_mappings]

        # build the big system prompt
        system_prompt = self.build_system_prompt(available_keys, policy_context)

        # temperature ladder: 0.1, 0.3, 0.5, ... clamped at 1.0.
        # Identical to the old hard-coded [0.1, 0.3, 0.5] for
        # num_variants <= 3, but no longer caps larger requests.
        temperatures = [min(0.1 + 0.2 * i, 1.0) for i in range(num_variants)]

        variants = []
        for i, temp in enumerate(temperatures):
            logger.debug(f"Generating variant {i+1} with temp={temp}...")

            result = self.generate_single_rule(prompt, system_prompt, temperature=temp)
            if not result:
                continue

            # validate it uses correct keys
            is_valid, error_msg = self.validate_rule(
                result['json_logic'],
                available_keys
            )
            if not is_valid:
                logger.warning(f"Variant {i+1} validation failed: {error_msg}")
                continue

            # test on mock data
            validation_rate = self.test_rule_on_mocks(result['json_logic'])
            result['validation_rate'] = validation_rate

            logger.info(f"Variant {i+1}: conf={result['confidence']:.3f}, val={validation_rate:.3f}")

            variants.append(result)

        if not variants:
            logger.error("Failed to generate any valid rule variants")
            raise ValueError("Failed to generate any valid rule variants")

        # vote for best rule
        best_rule = self.self_consistency_vote(variants)

        logger.success(f"Selected best rule")

        return best_rule

    except Exception as e:
        logger.error(f"Rule generation failed: {str(e)}")
        raise
def calculate_confidence_score(self, rule_result, key_mappings, policy_relevance):
    """Blend several quality signals into one overall 0-1 confidence.

    Weights: average key-mapping similarity 40%, policy relevance 30%,
    and the mean of the LLM's confidence and mock validation rate 30%.
    Returns 0.5 as a neutral fallback on any unexpected error.
    """
    try:
        # how well keys matched (40%)
        if key_mappings:
            avg_key_sim = sum(m.similarity for m in key_mappings) / len(key_mappings)
        else:
            avg_key_sim = 0.0

        # how relevant policies were (30%)
        policy_score = policy_relevance

        # llm confidence + validation rate (30%), each defaulting to 0.8
        generation_score = (
            rule_result.get('confidence', 0.8)
            + rule_result.get('validation_rate', 0.8)
        ) / 2

        # weighted average of the three signals
        weighted = (
            avg_key_sim * 0.4
            + policy_score * 0.3
            + generation_score * 0.3
        )

        # clamp to [0, 1]
        return float(min(max(weighted, 0.0), 1.0))

    except Exception as e:
        logger.error(f"Confidence calculation failed: {str(e)}")
        return 0.5  # default fallback