File size: 4,518 Bytes
b30f331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import re
from hashlib import md5
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer


# ============================================================================
# LAZY-LOAD EMBEDDING MODEL (load on first use)
# ============================================================================

# Module-level caches, populated on first use by get_embed_model() /
# get_flant5_tokenizer() so importing this module stays cheap (no model
# download/load at import time).
_embed_model = None
_flant5tokenizer = None

def get_embed_model():
    """Lazily construct and cache the sentence-embedding model.

    Returns:
        SentenceTransformer: the shared all-MiniLM-L6-v2 instance.
    """
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model

def get_flant5_tokenizer():
    """Lazily construct and cache the Flan-T5 tokenizer.

    Returns:
        The shared google/flan-t5-small tokenizer instance.
    """
    global _flant5tokenizer
    if _flant5tokenizer is not None:
        return _flant5tokenizer
    _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer


def normalize_text(text):
    """Normalize text for duplicate detection.

    Lowercases, strips punctuation, and collapses every whitespace run to a
    single space so superficially different strings compare equal.

    Args:
        text: Input to normalize; non-string input yields "".

    Returns:
        str: The normalized text.
    """
    if not isinstance(text, str):
        return ""

    text = text.lower()
    # Strip punctuation BEFORE collapsing whitespace: the previous order left
    # double spaces behind when removing tokens like " - ", so "a - b" and
    # "a b" did not normalize identically, defeating duplicate detection.
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text)  # Collapse whitespace runs to one space

    return text.strip()


def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split text into overlapping chunks sized for the embedding model.

    Args:
        text: Source document to split.
        chunk_size: Target maximum size of each chunk.
        chunk_overlap: Amount of overlap shared by consecutive chunks.

    Returns:
        list[str]: The chunk texts.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    return text_splitter.split_text(text)


def create_embeddings(texts):
    """Encode a batch of texts into L2-normalized embedding vectors.

    Args:
        texts: List of strings to embed.

    Returns:
        numpy array of embeddings, one normalized vector per input text.
    """
    model = get_embed_model()
    vectors = model.encode(
        texts,
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return vectors


def refine_response(answer):
    """Clean and format generated response text.
    Returns: str(refined_answer)"""
    # Replace multiple periods with single period
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(r'\.([^\s])', r'. \1', answer)
    
    # If no ending punctuation, find last one and truncate
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]
    
    # Capitalize first letter of each sentence
    sentences = re.split(r'([.!?]\s*)', answer)
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            refined_sentences.append(sentence_part.capitalize())
        if i + 1 < len(sentences):
            refined_sentences.append(sentences[i + 1])
    
    return ''.join(refined_sentences).strip()


def build_prompt(user_query, context, max_tokens=512):
    """Build a Flan-T5 prompt from retrieved context within a token budget.

    Greedily packs context entries in their given (ranked) order until the
    next entry would overflow the budget, then appends the user question.

    Args:
        user_query: The user's question.
        context: Ranked list of dicts with 'question' and 'chunk_answer' keys.
        max_tokens: Hard cap on total prompt tokens.

    Returns:
        str: The fully assembled prompt.
    """
    # Early exit BEFORE loading the tokenizer: the no-context prompt does not
    # need it, so we avoid a pointless (and heavy) first-time tokenizer load.
    if not context:
        return f"""No relevant medical information found.
Q: {user_query}
A: Information unavailable."""

    flant5tokenizer = get_flant5_tokenizer()

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"

    # Fixed token overhead: header + footer, plus a small safety margin (5)
    # for special tokens the model may add around the prompt.
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5

    # Remaining budget available for context chunks (never negative).
    remaining_tokens = max(max_tokens - total_static_cost, 0)

    # Greedily pack entries; stop at the first overflow so higher-ranked
    # entries are never displaced by later, smaller ones.
    valid_contexts = []
    current_context_tokens = 0

    for idx, c in enumerate(context, start=1):
        # Renamed from `chunk_text`, which shadowed the module-level
        # chunk_text() helper.
        entry_text = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        entry_len = len(flant5tokenizer.encode(entry_text, add_special_tokens=False))

        if current_context_tokens + entry_len > remaining_tokens:
            break

        valid_contexts.append(entry_text)
        current_context_tokens += entry_len

    the_context_block = "\n".join(valid_contexts)
    full_prompt = f"{instruction_text}{the_context_block}{query_footer}"

    return full_prompt