File size: 4,518 Bytes
b30f331 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import re
from hashlib import md5
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
# ============================================================================
# LAZY-LOAD EMBEDDING MODEL (load on first use)
# ============================================================================
_embed_model = None
_flant5tokenizer = None
def get_embed_model():
    """Return the shared SentenceTransformer instance, loading it on first call."""
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    # First use: download/load the MiniLM sentence encoder once and cache it.
    _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model
def get_flant5_tokenizer():
    """Return the shared Flan-T5 tokenizer, loading it on first call."""
    global _flant5tokenizer
    if _flant5tokenizer is not None:
        return _flant5tokenizer
    # First use: fetch the tokenizer once and cache it at module level.
    _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer
def normalize_text(text):
    """Normalize text for duplicate detection.

    Lowercases, strips punctuation, and collapses runs of whitespace to a
    single space so that texts differing only in case, punctuation, or
    spacing compare equal.

    Args:
        text: Candidate text; non-string input is treated as empty.

    Returns:
        str: The normalized text ("" for non-string input).
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()  # Case-insensitive comparison
    # Strip punctuation BEFORE collapsing whitespace: removing a punctuation
    # character that sits between spaces (e.g. "a - b") would otherwise leave
    # a double space and break duplicate matching against "a b".
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text)  # Collapse all whitespace runs
    return text.strip()
def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split text into overlapping chunks for embedding.

    Args:
        text: The text to split.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters of overlap between adjacent chunks.

    Returns:
        list: The text chunks.
    """
    # Prefer paragraph breaks, then line breaks, then spaces, then anywhere.
    separators = ["\n\n", "\n", " ", ""]
    splitter = RecursiveCharacterTextSplitter(
        separators=separators,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_text(text)
def create_embeddings(texts):
    """Encode a list of texts into L2-normalized numpy embeddings.

    Args:
        texts: List of strings to embed.

    Returns:
        numpy array of embeddings, one row per input text.
    """
    model = get_embed_model()
    # Normalized numpy output lets downstream code use plain dot products
    # for cosine similarity.
    encode_options = dict(
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return model.encode(texts, **encode_options)
def refine_response(answer):
    """Clean and format generated response text.

    Normalizes spacing around periods, truncates a trailing unfinished
    sentence, and uppercases the first letter of each sentence.

    Args:
        answer: Raw generated text.

    Returns:
        str: The refined answer.
    """
    # Normalize period spacing: drop runs of 2+ spaces after a period, then
    # re-insert exactly one space before the next non-space character.
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(r'\.([^\s])', r'. \1', answer)
    # If the text ends mid-sentence, truncate at the last complete sentence.
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]
    # Uppercase only the FIRST letter of each sentence. (The previous
    # str.capitalize() lowercased the rest of the sentence, mangling
    # acronyms and proper nouns, e.g. "WHO" -> "Who".)
    sentences = re.split(r'([.!?]\s*)', answer)  # even idx: text, odd idx: separator
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            refined_sentences.append(sentence_part[0].upper() + sentence_part[1:])
            if i + 1 < len(sentences):
                refined_sentences.append(sentences[i + 1])
    return ''.join(refined_sentences).strip()
def build_prompt(user_query, context, max_tokens=512):
    """Build a Flan-T5 prompt from retrieved context chunks and the query.

    Context chunks are added greedily in order until the token budget
    (max_tokens minus the fixed instruction/query overhead) is exhausted.

    Args:
        user_query: The user's question.
        context: Iterable of dicts with 'question' and 'chunk_answer' keys;
            falsy/empty means no relevant context was retrieved.
        max_tokens: Total token budget for the prompt.

    Returns:
        str: The full prompt text.
    """
    flant5tokenizer = get_flant5_tokenizer()
    if not context:
        # No retrieval hits: return a fixed fallback prompt.
        return f"""No relevant medical information found.
Q: {user_query}
A: Information unavailable."""
    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"
    # Fixed token overhead: instruction header + query footer + a small
    # safety margin (5, presumably for special tokens -- TODO confirm).
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5
    # Remaining budget for context chunks (clamped so it never goes negative).
    remaining_tokens = max(0, max_tokens - total_static_cost)
    # Greedily add chunks in retrieval order; stop at the first chunk that
    # would overflow the budget.
    valid_contexts = []
    current_context_tokens = 0
    for idx, c in enumerate(context, start=1):
        # Renamed from `chunk_text` so the local no longer shadows the
        # module-level chunk_text() helper.
        chunk_str = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        chunk_len = len(flant5tokenizer.encode(chunk_str, add_special_tokens=False))
        if current_context_tokens + chunk_len > remaining_tokens:
            break
        valid_contexts.append(chunk_str)
        current_context_tokens += chunk_len
    the_context_block = "\n".join(valid_contexts)
    full_prompt = f"{instruction_text}{the_context_block}{query_footer}"
    return full_prompt
|