# rag/utils.py
import re
from hashlib import md5
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
# ============================================================================
# LAZY-LOAD EMBEDDING MODEL (load on first use)
# ============================================================================
_embed_model = None
_flant5tokenizer = None

def get_embed_model():
    """Load embedding model (lazy-loaded on first use)."""
    global _embed_model
    if _embed_model is None:
        _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model

def get_flant5_tokenizer():
    """Load Flan-T5 tokenizer (lazy-loaded on first use)."""
    global _flant5tokenizer
    if _flant5tokenizer is None:
        _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer

def normalize_text(text):
    """Normalize text for duplicate detection: lowercase, collapse whitespace,
    and strip punctuation.
    Returns: str"""
    if not isinstance(text, str):
        return ""
    text = text.lower()                  # Lowercase
    text = re.sub(r'\s+', ' ', text)     # Collapse runs of whitespace
    text = text.strip()                  # Trim leading/trailing spaces
    text = re.sub(r'[^\w\s]', '', text)  # Strip punctuation
    return text
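
# Hypothetical helper (not in the original file): md5 is imported above but not
# used elsewhere in this module, which suggests normalized text is hashed into a
# duplicate-detection key downstream. This sketch shows one plausible pairing;
# the name `dedup_key` and its use are assumptions.
def dedup_key(text):
    """Hypothetical: stable fingerprint of normalized text for deduplication."""
    return md5(normalize_text(text).encode("utf-8")).hexdigest()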

def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split text into chunks for embeddings.
    Returns: list(chunks)"""
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    return splitter.split_text(text)
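
# Usage note (illustrative): RecursiveCharacterTextSplitter measures chunk_size
# with len() by default, so 384 here means characters, not tokens.
# chunks = chunk_text(long_document)  # list of <=384-character strings;
#                                     # adjacent chunks overlap by ~20 characters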

def create_embeddings(texts):
    """Create embeddings for a list of texts.
    Returns: list(embeddings)"""
    embed_model = get_embed_model()
    return embed_model.encode(
        texts,
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True
    )
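
# Because normalize_embeddings=True above, every vector is unit-length, so
# cosine similarity reduces to a dot product. Hypothetical helper (not in the
# original file) showing how retrieval scoring could exploit that:
def cosine_scores(query_text, corpus_embeddings):
    """Hypothetical: similarity of one query against N pre-computed unit vectors."""
    query_vec = create_embeddings([query_text])[0]  # shape (384,) for MiniLM-L6-v2
    return corpus_embeddings @ query_vec            # shape (N,); higher = closer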

def refine_response(answer):
    """Clean and format generated response text.
    Returns: str(refined_answer)"""
    # Collapse multiple spaces after a period into a single space
    answer = re.sub(r'\. {2,}', '. ', answer)
    # Ensure a space follows a period that runs into the next word
    answer = re.sub(r'\.([^\s])', r'. \1', answer)
    # If the text does not end with terminal punctuation, truncate at the last one
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]
    # Capitalize the first letter of each sentence (without lowercasing the rest,
    # which str.capitalize() would do to acronyms like "COVID")
    sentences = re.split(r'([.!?]\s*)', answer)
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            refined_sentences.append(sentence_part[0].upper() + sentence_part[1:])
        if i + 1 < len(sentences):
            refined_sentences.append(sentences[i + 1])
    return ''.join(refined_sentences).strip()
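
# Example (illustrative) of refine_response behavior:
#   refine_response("take rest.  drink fluids.see a doctor if fever")
#   -> "Take rest. Drink fluids."
# The unterminated trailing clause is dropped by the truncation step above.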

def build_prompt(user_query, context, max_tokens=512):
    """Build prompt with context and query within the token limit.
    Returns: str(full_prompt)"""
    flant5tokenizer = get_flant5_tokenizer()
    if not context:
        return f"""No relevant medical information found.
Q: {user_query}
A: Information unavailable."""
    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"
    # Token overhead of the fixed parts (+5 leaves headroom for special tokens)
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5
    # Remaining token budget for context chunks
    remaining_tokens = max(0, max_tokens - total_static_cost)
    # Greedily pack chunks, stopping at the first one that would overflow
    valid_contexts = []
    current_context_tokens = 0
    for idx, c in enumerate(context, start=1):
        chunk_str = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        chunk_len = len(flant5tokenizer.encode(chunk_str, add_special_tokens=False))
        if current_context_tokens + chunk_len > remaining_tokens:
            break
        valid_contexts.append(chunk_str)
        current_context_tokens += chunk_len
    the_context_block = "\n".join(valid_contexts)
    full_prompt = f"{instruction_text}{the_context_block}{query_footer}"
    return full_prompt
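
# End-to-end sketch (illustrative): the retrieval step that produces `context`
# lives elsewhere in the app; the sample entry below is made up, but it uses the
# 'question'/'chunk_answer' keys that build_prompt expects.
if __name__ == "__main__":
    sample_context = [
        {"question": "What causes dehydration?",
         "chunk_answer": "Dehydration occurs when fluid loss exceeds intake."},
    ]
    print(build_prompt("How do I treat dehydration?", sample_context))
    # Expected output:
    # Medical Context:
    # [C1] What causes dehydration?
    # Dehydration occurs when fluid loss exceeds intake.
    # Q: How do I treat dehydration?
    # A: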