amiraghhh committed on
Commit a09c868 · verified · 1 Parent(s): ed87cc6

Upload 5 files

Files changed (5)
  1. app.py +13 -0
  2. model.py +130 -0
  3. requirements.txt +9 -0
  4. retriever.py +163 -0
  5. utils.py +114 -0
app.py ADDED
@@ -0,0 +1,13 @@
import gradio as gr
from model import rag

demo = gr.Interface(
    fn=rag,
    inputs=gr.Textbox(lines=5, label="Enter your medical question here"),
    outputs=gr.Textbox(lines=10, label="Generated Answer"),
    title="Medical QA",
    description="Type your query and get a model generated response.",
    theme=gr.themes.Glass()
)

demo.launch()
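Note: app.py only wires the Gradio UI to rag() from model.py, and demo.launch() starts a local server when the file is run directly. The sketch below is one way to smoke-test the pipeline from a Python shell without the UI; the sample question is illustrative, and it assumes the ./MedQuAD_db store and the Hub model weights are reachable.

# Quick test of the RAG pipeline without launching Gradio.
# The question is only an example; any medical query string works.
from model import rag

answer = rag("What are the symptoms of glaucoma?")
print(answer)  # generated answer plus a confidence line or a low-confidence disclaimer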
model.py ADDED
@@ -0,0 +1,130 @@
import chromadb
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from retriever import retrieve
from utils import build_prompt, refine_response


# ============================================================================
# LOAD MODELS AND VECTOR STORE ONCE AT STARTUP
# ============================================================================

# Load vector store
db_client = chromadb.PersistentClient(path="./MedQuAD_db")
try:
    vector_store = db_client.get_collection("medical_rag")
except Exception:
    # If the collection doesn't exist yet, create it
    vector_store = db_client.create_collection(name="medical_rag")

# Load fine-tuned model for generation
ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)

finetuned_llm = pipeline(
    "text2text-generation",
    model=ft_model,
    tokenizer=ft_tokenizer,
    decoder_start_token_id=ft_model.config.pad_token_id
)

# Load base model config for EOS token
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")


# ============================================================================
# MAIN RAG FUNCTION
# ============================================================================

def rag(user_query):
    """Main RAG function: retrieve context and generate an answer.
    Takes a question string and returns an answer string with a confidence note.
    Returns: str(generated_answer)"""

    try:
        # 1. Check for emergency keywords
        emergency_keywords = ["emergency", "severe pain", "bleeding",
                              "blind", "lose consciousness", "pass out"]

        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            emergency_msg = """I am an AI and cannot provide medical advice for emergencies.
PLEASE contact emergency services or a medical professional immediately."""

            try:
                # Still generate an answer for context
                # (detail=True so each context carries chroma_distance for the confidence score)
                contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True, detail=True)

                if not contexts:
                    return f"{emergency_msg}\n\nNo relevant information found for your query."

                prompt = build_prompt(user_query, contexts)
                result = finetuned_llm(
                    prompt,
                    max_new_tokens=70,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    repetition_penalty=1.4,
                    eos_token_id=finetuned_llm.tokenizer.eos_token_id
                )

                answer = result[0]['generated_text'].strip()
                answer = refine_response(answer)

                # Calculate confidence
                if contexts:
                    avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
                    confidence_score = (1 - avg_distance) * 100
                    confidence_score = max(0, min(100, confidence_score))
                else:
                    confidence_score = 0

                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"

            except Exception as e:
                return f"{emergency_msg}\n\nError generating answer: {str(e)}"

        # 2. Retrieve relevant contexts (detail=True so chroma_distance is available below)
        contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True, detail=True)

        if not contexts:
            return "I'm not confident about my answer (0%).\n\nCouldn't find relevant information to answer your question."

        # 3. Build prompt with context
        prompt = build_prompt(user_query, contexts)

        # 4. Generate answer
        result = finetuned_llm(
            prompt,
            max_new_tokens=70,
            num_beams=3,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.4,
            eos_token_id=finetuned_llm.tokenizer.eos_token_id
        )

        answer = result[0]['generated_text'].strip()
        answer = refine_response(answer)

        # 5. Calculate confidence score based on retrieval quality
        if contexts and len(contexts) > 0:
            avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
            confidence_score = (1 - avg_distance) * 100
            confidence_score = max(0, min(100, confidence_score))

            # Build final response with confidence
            if confidence_score < 40:
                final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
            else:
                final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
        else:
            final_response = "I'm not confident about my answer (0%).\n\n" + answer

        return final_response

    except Exception as e:
        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
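For reference, the confidence score in rag() is simply the clamped complement of the average Chroma distance over the retrieved contexts. The snippet below reproduces that arithmetic on made-up distances; the contexts must carry chroma_distance, which retrieve() only includes when called with detail=True.

# Illustrative numbers only: three retrieved chunks with hypothetical distances.
contexts = [
    {"chroma_distance": 0.30},
    {"chroma_distance": 0.35},
    {"chroma_distance": 0.40},
]
avg_distance = sum(c.get("chroma_distance", 1.0) for c in contexts) / len(contexts)  # 0.35
confidence_score = max(0, min(100, (1 - avg_distance) * 100))                        # 65.0
print(f"[Confidence: {confidence_score:.1f}%]")  # 65.0% is above the 40% threshold, so no disclaimer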
requirements.txt ADDED
@@ -0,0 +1,9 @@
gradio==4.26.0
torch==2.1.2
transformers==4.37.2
sentence-transformers==2.2.2
chromadb==0.4.24
langchain-text-splitters==0.0.1
accelerate==0.27.2
numpy==1.24.3
pandas==2.0.3
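These pins are what the Space installs on startup; for a local run, `pip install -r requirements.txt` in a fresh virtual environment is the usual route. The Python version is not pinned in this upload, so treating it as 3.10-ish is an assumption.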
retriever.py ADDED
@@ -0,0 +1,163 @@
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from utils import embed_model


# Initialize reranker model for relevance scoring
rerank_tokenizer = AutoTokenizer.from_pretrained("castorini/monot5-base-msmarco")
rerank_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/monot5-base-msmarco")
rerank_model.eval()

# Initialize query rewriter
rewriter_llm = pipeline(
    "text2text-generation",
    model="google/flan-t5-small",
    max_length=64,
    do_sample=False,
    temperature=0.3,
    repetition_penalty=1.3,
    no_repeat_ngram_size=2
)


def rewrite_query(user_query):
    """Rewrite the user query to be more specific and medical-focused.
    Returns: str(rewritten_query)"""

    prompt = f"""Rewrite the input into a clear medical question following these patterns
Input: my head hurts
Output: What causes headaches?

Input: i keep vomiting but feel ok afterwards
Output: What causes cyclic vomiting?

Input: chest pain when breathing
Output: What causes chest pain during breathing?

Input: {user_query}
Output:
"""

    llm_output = rewriter_llm(prompt)
    rewritten_query = llm_output[0]['generated_text']
    rewritten_query = rewritten_query.replace("Output:", "").strip()

    return rewritten_query


def get_monot5_scores(rewritten_query, chunks):
    """Calculate relevance scores of chunks against the query using the MonoT5 reranker.
    Returns: list(scores)"""

    # Format each input as "Query: ABC? Document: XYZ... Relevant:"
    inputs = [f"Query: {rewritten_query} Document: {c} Relevant:" for c in chunks]

    # Tokenize
    tokenized_inputs = rerank_tokenizer(
        inputs,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Generate predictions
    with torch.no_grad():
        outputs = rerank_model.generate(
            input_ids=tokenized_inputs['input_ids'],
            attention_mask=tokenized_inputs['attention_mask'],
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True
        )

    # Extract the logit of the "true" token as the relevance score
    true_token_id = rerank_tokenizer.encode("true")[0]
    batch_scores = outputs.scores[0][:, true_token_id]

    return batch_scores.tolist()


def retrieve(vector_store, query, top_k=3, use_reranking=True, detail=False):
    """Retrieve relevant context chunks for a query, with optional reranking.
    Returns: list(contexts)"""

    is_single_query = isinstance(query, str)
    queries = [query] if is_single_query else query

    # Rewrite queries for better matching
    rewritten_queries = [rewrite_query(q) for q in queries]

    # Embed rewritten queries
    q_embeddings = embed_model.encode(rewritten_queries).tolist()

    # Vector search
    search_results = vector_store.query(
        query_embeddings=q_embeddings,
        n_results=10
    )

    all_contexts = []

    for i in range(len(queries)):
        contexts_for_query = []

        if not search_results['documents'][i]:
            all_contexts.append([])
            continue

        retrieved_chunks = search_results['documents'][i]
        retrieved_metas = search_results['metadatas'][i]
        retrieved_ids = search_results['ids'][i]
        retrieved_distances = search_results.get('distances', [[]])[i]

        # Rerank if enabled
        if use_reranking:
            rerank_scores = get_monot5_scores(rewritten_queries[i], retrieved_chunks)

            candidates = []
            for r in range(len(retrieved_chunks)):
                candidates.append({
                    'chunk_id': retrieved_ids[r],
                    'chunk_answer': retrieved_chunks[r],
                    'question': retrieved_metas[r]['question'],
                    'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
                    'chroma_dist': retrieved_distances[r],
                    'rerank_score': rerank_scores[r]
                })

            # Sort by rerank score
            candidates.sort(key=lambda s: s['rerank_score'], reverse=True)
            sorted_results = candidates[:top_k]
        else:
            # Just use the top-k from vector search
            sorted_results = [
                {
                    'chunk_id': retrieved_ids[r],
                    'chunk_answer': retrieved_chunks[r],
                    'question': retrieved_metas[r]['question'],
                    'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
                    'chroma_dist': retrieved_distances[r]
                }
                for r in range(min(top_k, len(retrieved_chunks)))
            ]

        # Format output
        for item_dict in sorted_results:
            item = {
                'chunk_id': item_dict['chunk_id'],
                'chunk_answer': item_dict['chunk_answer'],
                'question': item_dict['question']
            }
            if detail:
                item['focus_area'] = item_dict['focus_area']
                item['chroma_distance'] = round(item_dict['chroma_dist'], 3)
                if 'rerank_score' in item_dict:
                    item['rerank_score'] = round(item_dict['rerank_score'], 3)
            contexts_for_query.append(item)

        all_contexts.append(contexts_for_query)

    if is_single_query:
        return all_contexts[0]
    else:
        return all_contexts
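In short, retrieve() rewrites the query with FLAN-T5, embeds it with MiniLM, pulls 10 candidates from Chroma, optionally reranks them with MonoT5, and returns the top_k as plain dicts. A minimal usage sketch against the same persistent store that model.py opens; the query string is illustrative.

# Usage sketch for the retriever; the question is made up.
import chromadb
from retriever import retrieve

client = chromadb.PersistentClient(path="./MedQuAD_db")
collection = client.get_collection("medical_rag")

contexts = retrieve(collection, "What causes high blood pressure?",
                    top_k=3, use_reranking=True, detail=True)
for c in contexts:
    # chunk_id, chunk_answer and question are always present;
    # focus_area, chroma_distance and rerank_score are added only when detail=True.
    print(c["chunk_id"], c.get("rerank_score"), c["question"])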
utils.py ADDED
@@ -0,0 +1,114 @@
import re
from hashlib import md5
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

# Initialize embedding model (used across multiple modules)
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")


def normalize_text(text):
    """Normalize text for duplicate detection by removing spaces and punctuation.
    Returns: str"""
    if not isinstance(text, str):
        return ""

    text = text.lower()                  # Lowercase
    text = re.sub(r'\s+', ' ', text)     # Collapse extra whitespace
    text = text.strip()                  # Remove leading/trailing spaces
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation

    return text


def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split text into chunks for embeddings.
    Returns: list(chunks)"""
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    return splitter.split_text(text)


def create_embeddings(texts):
    """Create embeddings for a list of texts.
    Returns: list(embeddings)"""
    return embed_model.encode(
        texts,
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True
    )


def refine_response(answer):
    """Clean and format generated response text.
    Returns: str(refined_answer)"""
    # Drop stray spaces after a period, then ensure a single space follows each period
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(r'\.([^\s])', r'. \1', answer)

    # If there is no ending punctuation, truncate at the last punctuation mark
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]

    # Capitalize the first letter of each sentence
    sentences = re.split(r'([.!?]\s*)', answer)
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            refined_sentences.append(sentence_part.capitalize())
        if i + 1 < len(sentences):
            refined_sentences.append(sentences[i + 1])

    return ''.join(refined_sentences).strip()


def build_prompt(user_query, context, max_tokens=512):
    """Build a prompt with context and query that stays within the token limit.
    Returns: str(full_prompt)"""

    if not context:
        return f"""No relevant medical information found.
Q: {user_query}
A: Information unavailable."""

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"

    # Calculate token overhead of the fixed parts of the prompt
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5

    # Calculate the remaining budget for context
    remaining_tokens = max_tokens - total_static_cost
    if remaining_tokens < 0:
        remaining_tokens = 0

    # Fill the context budget
    valid_contexts = []
    current_context_tokens = 0

    for idx, c in enumerate(context, start=1):
        chunk_str = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        chunk_len = len(flant5tokenizer.encode(chunk_str, add_special_tokens=False))

        if current_context_tokens + chunk_len > remaining_tokens:
            break

        valid_contexts.append(chunk_str)
        current_context_tokens += chunk_len

    the_context_block = "\n".join(valid_contexts)
    full_prompt = f"{instruction_text}{the_context_block}{query_footer}"

    return full_prompt
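The indexing script that builds ./MedQuAD_db is not part of this upload, but chunk_text() and create_embeddings() are the pieces it would need. The sketch below shows one plausible way to populate the medical_rag collection with the metadata keys (question, focus_area) that retriever.py expects; the record contents and ids are illustrative assumptions.

# Hedged sketch: how a MedQuAD record could be chunked, embedded and stored.
import chromadb
from utils import chunk_text, create_embeddings

client = chromadb.PersistentClient(path="./MedQuAD_db")
collection = client.get_or_create_collection("medical_rag")

# Hypothetical record; real data would come from the MedQuAD dataset.
record = {
    "question": "What is glaucoma?",
    "answer": "Glaucoma is a group of diseases that damage the optic nerve...",
    "focus_area": "Glaucoma",
}

chunks = chunk_text(record["answer"])            # <=384-character chunks with 20-char overlap
embeddings = create_embeddings(chunks).tolist()  # normalized MiniLM vectors

collection.add(
    ids=[f"glaucoma-{i}" for i in range(len(chunks))],
    documents=chunks,
    embeddings=embeddings,
    metadatas=[{"question": record["question"], "focus_area": record["focus_area"]}
               for _ in range(len(chunks))],
)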