| import re |
| from hashlib import md5 |
| from sentence_transformers import SentenceTransformer |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from transformers import AutoTokenizer |
|
|
|
|
| |
| |
| |
|
|
# Lazily-initialised shared objects. Both start as None and are filled in on
# first use by get_embed_model() and get_flant5_tokenizer() respectively.
_embed_model, _flant5tokenizer = None, None
|
|
def get_embed_model():
    """Return the shared sentence-embedding model, creating it on first use.

    The "sentence-transformers/all-MiniLM-L6-v2" model is built only on the
    first call. It is then cached in the module-level ``_embed_model`` global,
    so later calls return that same instance.
    """
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model
|
|
def get_flant5_tokenizer():
    """Return the shared Flan-T5 tokenizer, creating it on first use.

    The "google/flan-t5-small" tokenizer is loaded only on the first call. It
    is then cached in the module-level ``_flant5tokenizer`` global, so
    repeated calls reuse one instance.
    """
    global _flant5tokenizer
    if _flant5tokenizer is not None:
        return _flant5tokenizer
    _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer
|
|
|
|
def normalize_text(text):
    """Normalize text for duplicate detection.

    Processing steps:

    - Lower-case the text.
    - Remove punctuation, meaning every character that is neither a word
      character nor whitespace.
    - Collapse each run of whitespace into a single space.
    - Trim both ends.

    Non-string input normalizes to "".

    Returns: str
    """
    if not isinstance(text, str):
        return ""

    # Strip punctuation *before* collapsing whitespace. Doing it afterwards
    # leaves gaps behind ("a - b" -> "a  b", "hello !" -> "hello "), so
    # near-identical strings would not compare equal.
    without_punct = re.sub(r'[^\w\s]', '', text.lower())
    return re.sub(r'\s+', ' ', without_punct).strip()
|
|
|
|
def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split ``text`` into overlapping pieces for embedding.

    Uses RecursiveCharacterTextSplitter with the separators paragraph break,
    line break, space, and empty string, tried in that order.
    ``chunk_size`` and ``chunk_overlap`` are passed straight through. They are
    measured with the splitter's default length function (presumably
    characters; TODO confirm).

    Returns: list(chunks)
    """
    separators = ["\n\n", "\n", " ", ""]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=separators,
    )
    chunks = splitter.split_text(text)
    return chunks
|
|
|
|
def create_embeddings(texts):
    """Encode ``texts`` with the shared embedding model.

    Encoding runs in batches of 64 with the progress bar disabled. The output
    is requested as NumPy (``convert_to_numpy=True``) with
    ``normalize_embeddings=True``.

    Returns: list(embeddings), i.e. whatever ``encode`` returns for these
    options (presumably a NumPy array; TODO confirm).
    """
    encode_options = {
        "batch_size": 64,
        "show_progress_bar": False,
        "convert_to_numpy": True,
        "normalize_embeddings": True,
    }
    return get_embed_model().encode(texts, **encode_options)
|
|
|
|
# Regex for a period that is NOT a decimal point. It matches "." unless the
# period sits between two digits, so "2.5" is left alone while "end.Next"
# and "word.5" still match.
_NON_DECIMAL_PERIOD = r'(?:(?<!\d)\.|\.(?!\d))'


def refine_response(answer):
    """Clean and format generated response text.

    1. Collapse two or more spaces after a period. Then insert a single
       space after any period that is directly followed by text. Decimal
       points inside numbers (e.g. "2.5 mg") are left untouched.
    2. If the text does not end in ".", "!" or "?", drop the trailing
       fragment after the last sentence terminator. Decimal points do not
       count as terminators, and text with no terminator is kept whole.
    3. Upper-case the first character of each sentence WITHOUT lowercasing
       the rest, so acronyms such as "FDA" or "HIV" survive.

    Returns: str(refined_answer)
    """
    # ".   Next" -> ".Next"; the following substitution re-inserts one space.
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(_NON_DECIMAL_PERIOD + r'(\S)', r'. \1', answer)

    # Cut off an unfinished trailing sentence, but never at a decimal point
    # (otherwise "Use 0.5 mg" would become "Use 0.").
    if not answer.strip().endswith(('.', '!', '?')):
        terminators = list(re.finditer(r'[!?]|' + _NON_DECIMAL_PERIOD, answer))
        if terminators:
            answer = answer[:terminators[-1].end()]

    # The capturing group keeps the delimiters, so the list alternates
    # [text, delimiter, text, delimiter, ..., text].
    pieces = re.split(r'([.!?]\s*)', answer)
    refined = []
    for i in range(0, len(pieces), 2):
        sentence = pieces[i].strip()
        if sentence:
            # str.capitalize() would lowercase the remainder ("FDA" -> "fda").
            refined.append(sentence[0].upper() + sentence[1:])
        if i + 1 < len(pieces):
            refined.append(pieces[i + 1])

    return ''.join(refined).strip()
|
|
|
|
def build_prompt(user_query, context, max_tokens=512):
    """Build a prompt from retrieved context and the user query within a token budget.

    Context entries are added in the given order until the next one would
    exceed the remaining budget. The first entry that does not fit ends the
    selection; later entries are not tried.

    Args:
        user_query: question text placed after the context block.
        context: sequence of dicts, each read for the keys 'question' and
            'chunk_answer'. An empty or None value yields a fixed
            "no information" prompt without loading the tokenizer.
        max_tokens: token budget for the whole prompt, measured with the
            Flan-T5 tokenizer without special tokens.

    Returns: str(full_prompt)
    """
    if not context:
        # Return before touching the tokenizer; it is not needed here.
        return (
            "No relevant medical information found.\n"
            f"Q: {user_query}\n"
            "A: Information unavailable."
        )

    tokenizer = get_flant5_tokenizer()

    def token_len(text):
        # Token count without special tokens (e.g. an end-of-sequence marker).
        return len(tokenizer.encode(text, add_special_tokens=False))

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"

    # The extra 5 tokens are a fixed safety margin on top of the header and
    # footer (presumably headroom for special tokens; TODO confirm).
    static_cost = token_len(instruction_text) + token_len(query_footer) + 5
    remaining_tokens = max(0, max_tokens - static_cost)

    separator = "\n"
    separator_cost = token_len(separator)

    selected = []
    used_tokens = 0
    for idx, c in enumerate(context, start=1):
        # Named `entry` (not `chunk_text`) to avoid shadowing the module-level
        # chunk_text() function.
        entry = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        # Every entry after the first is preceded by a separator in the
        # joined block, so charge for it as well.
        cost = token_len(entry) + (separator_cost if selected else 0)
        if used_tokens + cost > remaining_tokens:
            break
        selected.append(entry)
        used_tokens += cost

    context_block = separator.join(selected)
    return f"{instruction_text}{context_block}{query_footer}"
|
|