|
|
import re |
|
|
from hashlib import md5 |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level singletons, populated lazily by the accessor functions in this
# module so that importing it stays cheap (no model downloads at import time).
_embed_model = None


_flant5tokenizer = None
|
|
|
|
|
def get_embed_model():
    """Return the shared SentenceTransformer, instantiating it on first call."""
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    # First use: download/load the MiniLM sentence-embedding model once.
    _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model
|
|
|
|
|
def get_flant5_tokenizer():
    """Return the shared Flan-T5 tokenizer, instantiating it on first call."""
    global _flant5tokenizer
    if _flant5tokenizer is not None:
        return _flant5tokenizer
    # First use: download/load the tokenizer once and cache it module-wide.
    _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer
|
|
|
|
|
|
|
|
def normalize_text(text):
    """Canonicalize a string for duplicate detection.

    Lowercases, collapses whitespace runs to single spaces, strips the
    ends, then removes punctuation (in that order — removing punctuation
    first could leave doubled spaces behind).

    Returns: str"""
    # Non-string input (None, numbers, ...) normalizes to the empty string.
    if not isinstance(text, str):
        return ""

    collapsed = re.sub(r'\s+', ' ', text.lower()).strip()
    return re.sub(r'[^\w\s]', '', collapsed)
|
|
|
|
|
|
|
|
def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split text into chunks for embeddings.

    Splits on paragraph, then line, then word boundaries before falling
    back to raw characters.

    Returns: list(chunks)"""
    splitter_config = {
        "separators": ["\n\n", "\n", " ", ""],
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
    }
    return RecursiveCharacterTextSplitter(**splitter_config).split_text(text)
|
|
|
|
|
|
|
|
def create_embeddings(texts):
    """Create embeddings for a list of texts.

    Vectors are L2-normalized so cosine similarity reduces to a dot product.

    Returns: list(embeddings)"""
    encode_options = {
        "batch_size": 64,
        "show_progress_bar": False,
        "convert_to_numpy": True,
        "normalize_embeddings": True,
    }
    model = get_embed_model()
    return model.encode(texts, **encode_options)
|
|
|
|
|
|
|
|
def refine_response(answer):
    """Clean and format generated response text.

    Normalizes spacing around periods, drops a trailing fragment that was
    cut off mid-sentence, and uppercases the first letter of each sentence.

    Args:
        answer: raw generated text (str).

    Returns: str(refined_answer)"""
    # Collapse runs of spaces after a period, then ensure exactly one space
    # follows every period.
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(r'\.([^\s])', r'. \1', answer)

    # If the text does not end on sentence punctuation, truncate back to the
    # last complete sentence (generation was likely cut off by a token limit).
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]

    # Split into (sentence, separator) pairs; separators are kept by the
    # capturing group so the original punctuation/spacing is reassembled.
    sentences = re.split(r'([.!?]\s*)', answer)
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            # Bug fix: str.capitalize() lowercases the rest of the sentence,
            # mangling acronyms/proper nouns ("HIV" -> "hiv"). Uppercase only
            # the first character instead.
            refined_sentences.append(sentence_part[0].upper() + sentence_part[1:])
            if i + 1 < len(sentences):
                refined_sentences.append(sentences[i + 1])

    return ''.join(refined_sentences).strip()
|
|
|
|
|
|
|
|
def build_prompt(user_query, context, max_tokens=512):
    """Build prompt with context and query within token limit.

    Greedily packs context chunks (in the given order, presumably most
    relevant first — confirm against caller) until the token budget is spent.

    Args:
        user_query: the user's question (str).
        context: iterable of dicts with 'question' and 'chunk_answer' keys.
        max_tokens: hard budget for the whole prompt, in tokenizer tokens.

    Returns: str(full_prompt)"""
    flant5tokenizer = get_flant5_tokenizer()

    # No retrieved context at all: emit a fixed fallback prompt.
    if not context:
        return f"""No relevant medical information found.
Q: {user_query}
A: Information unavailable."""

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"

    # Token cost of the fixed prompt parts; +5 is headroom for special tokens.
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5

    # Budget left for context chunks; clamp at zero if the static parts
    # already exceed max_tokens.
    remaining_tokens = max(0, max_tokens - total_static_cost)

    # Greedily accept chunks until the next one would overflow the budget.
    # Bug fix: the loop variable was named `chunk_text`, shadowing the
    # module-level chunk_text() helper; renamed to avoid the shadow.
    valid_contexts = []
    current_context_tokens = 0
    for idx, c in enumerate(context, start=1):
        entry_text = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        entry_len = len(flant5tokenizer.encode(entry_text, add_special_tokens=False))

        if current_context_tokens + entry_len > remaining_tokens:
            break

        valid_contexts.append(entry_text)
        current_context_tokens += entry_len

    the_context_block = "\n".join(valid_contexts)
    full_prompt = f"{instruction_text}{the_context_block}{query_footer}"

    return full_prompt
|
|
|