import re
from hashlib import md5
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

# ============================================================================
# LAZY-LOAD EMBEDDING MODEL AND TOKENIZER (load on first use)
# ============================================================================
_embed_model = None
_flant5tokenizer = None


def get_embed_model():
    """Load embedding model (lazy-loaded on first use)"""
    global _embed_model
    if _embed_model is None:
        _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model


def get_flant5_tokenizer():
    """Load Flan-T5 tokenizer (lazy-loaded on first use)"""
    global _flant5tokenizer
    if _flant5tokenizer is None:
        _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer


def normalize_text(text):
    """Normalize text for duplicate detection by lowercasing, collapsing whitespace,
    and stripping punctuation. Returns: str"""
    if not isinstance(text, str):
        return ""
    text = text.lower()                   # Lowercase
    text = re.sub(r'\s+', ' ', text)      # Collapse whitespace runs to single spaces
    text = text.strip()                   # Remove leading/trailing spaces
    text = re.sub(r'[^\w\s]', '', text)   # Remove punctuation
    return text


def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split text into chunks for embeddings. Returns: list(chunks)"""
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    return splitter.split_text(text)


def create_embeddings(texts):
    """Create embeddings for a list of texts. Returns: list(embeddings)"""
    embed_model = get_embed_model()
    return embed_model.encode(
        texts,
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True
    )


def refine_response(answer):
    """Clean and format generated response text. Returns: str(refined_answer)"""
    # Normalize spacing around periods: drop runs of spaces after a period,
    # then ensure exactly one space follows each period
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(r'\.([^\s])', r'. \1', answer)

    # If the answer does not end with punctuation, truncate at the last complete sentence
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]

    # Capitalize the first letter of each sentence
    sentences = re.split(r'([.!?]\s*)', answer)
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            refined_sentences.append(sentence_part.capitalize())
            if i + 1 < len(sentences):
                refined_sentences.append(sentences[i + 1])
    return ''.join(refined_sentences).strip()


def build_prompt(user_query, context, max_tokens=512):
    """Build prompt with context and query within token limit. Returns: str(full_prompt)"""
    flant5tokenizer = get_flant5_tokenizer()
    if not context:
        return f"""No relevant medical information found.
Q: {user_query} A: Information unavailable."""

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"

    # Calculate token overhead of the fixed prompt parts (+5 as a small safety margin)
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5

    # Calculate remaining budget for context
    remaining_tokens = max_tokens - total_static_cost
    if remaining_tokens < 0:
        remaining_tokens = 0

    # Fill the context budget, stopping before the first chunk that would exceed it
    valid_contexts = []
    current_context_tokens = 0
    for idx, c in enumerate(context, start=1):
        # Format one retrieved chunk with a citation tag (chunk_str avoids shadowing chunk_text())
        chunk_str = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        chunk_len = len(flant5tokenizer.encode(chunk_str, add_special_tokens=False))
        if current_context_tokens + chunk_len > remaining_tokens:
            break
        valid_contexts.append(chunk_str)
        current_context_tokens += chunk_len

    the_context_block = "\n".join(valid_contexts)
    full_prompt = f"{instruction_text}{the_context_block}{query_footer}"
    return full_prompt
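

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the pipeline above): shows one way the
# helpers could be wired together for a single query. The sample text, the
# mocked retrieved-context dict, and the parameter values are assumptions for
# demonstration only; in the real system the retrieval step would supply the
# context. Models are downloaded on first use.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    sample_doc = (
        "Aspirin is a nonsteroidal anti-inflammatory drug. "
        "It is commonly used to reduce pain, fever, and inflammation."
    )

    # 1. Split the source text into chunks and embed them.
    chunks = chunk_text(sample_doc, chunk_size=128, chunk_overlap=10)
    embeddings = create_embeddings(chunks)
    print(f"{len(chunks)} chunk(s) -> embedding matrix of shape {embeddings.shape}")

    # 2. Pretend retrieval returned one matching chunk; build the prompt.
    retrieved = [{"question": "What is aspirin used for?", "chunk_answer": chunks[0]}]
    prompt = build_prompt("What is aspirin used for?", retrieved, max_tokens=256)
    print(prompt)

    # 3. Clean up a raw model answer before returning it to the user.
    raw_answer = "aspirin reduces pain.it also lowers fever"
    print(refine_response(raw_answer))  # -> "Aspirin reduces pain."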