abnuel committed on
Commit
1d53ee2
·
verified ·
1 Parent(s): 84bc192

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -206
app.py CHANGED
@@ -1,206 +1,206 @@
1
- import pandas as pd
2
- import torch
3
- import os
4
- from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
5
- from huggingface_hub import login
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from langchain.embeddings import SentenceTransformerEmbeddings
8
- from langchain.vectorstores import FAISS
9
- from langchain.chains import RetrievalQA
10
- from langchain.prompts import PromptTemplate
11
- from langchain.llms import HuggingFacePipeline
12
- from langchain_community.document_loaders.csv_loader import CSVLoader
13
- import transformers
14
- from langchain.schema import Document
15
- import gradio as gr
16
- import re
17
-
18
-
19
- model = "abnuel/MedGemma-4b-ICD"
20
-
21
- #tokenizer = AutoTokenizer.from_pretrained("abnuel/MedGemma-4b-ICD")
22
-
23
- SYSTEM_PROMPT = "You are an expert medical coder. Your task is to analyze the clinical description provided and output only the single, most appropriate ICD-10-CM code. Do not include any text, justification other than the code itself."
24
-
25
- bnb_config = BitsAndBytesConfig(
26
- load_in_4bit=True,
27
- bnb_4bit_quant_type="nf4",
28
- bnb_4bit_use_double_quant=True,
29
- bnb_4bit_compute_dtype=torch.bfloat16
30
- )
31
- tokenizer = AutoTokenizer.from_pretrained(model)
32
- if tokenizer.pad_token is None:
33
- tokenizer.pad_token = tokenizer.eos_token
34
-
35
- model = AutoModelForCausalLM.from_pretrained(
36
- model,
37
- quantization_config=bnb_config,
38
- device_map="auto"
39
- )
40
-
41
-
42
- def generate_response(clinical_note):
43
- messages = [
44
- {"role": "system", "content": SYSTEM_PROMPT},
45
- {"role": "user", "content": f"Code the following: {clinical_note}"},
46
- ]
47
-
48
- # 3. Apply chat template and tokenize
49
-
50
- inputs = tokenizer.apply_chat_template(
51
- messages,
52
- add_generation_prompt=True,
53
- tokenize=True,
54
- return_dict=True,
55
- return_tensors="pt"
56
- ).to(model.device)
57
-
58
- input_len = inputs["input_ids"].shape[-1]
59
-
60
- # 4. Generate the response
61
- with torch.inference_mode():
62
- generation = model.generate(
63
- **inputs,
64
- max_new_tokens=200, # Max length of the generated ICD codes
65
- do_sample=False, # Use greedy decoding for predictable output
66
- temperature=0.0, # Zero temperature for deterministic results
67
- )
68
-
69
- # 5. Decode the output
70
- # Extract only the newly generated tokens
71
- generation = generation[0][input_len:]
72
- decoded_output = tokenizer.decode(generation, skip_special_tokens=True)
73
-
74
- return decoded_output.strip()
75
-
76
- # --- Example Usage ---
77
- #test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."
78
-
79
-
80
- #print(f"Clinical Note: {test_note}")
81
- #response = generate_response(test_note)
82
- #print(f"Generated ICD Codes: {response}")
83
-
84
-
85
-
86
- pipe = transformers.pipeline(
87
- "text-generation",
88
- model=model,
89
- tokenizer=tokenizer,
90
- max_new_tokens=50,
91
- temperature=0.1,
92
- do_sample=False,
93
- pad_token_id=tokenizer.eos_token_id,
94
-
95
- )
96
- hf_llm = HuggingFacePipeline(pipeline=pipe)
97
-
98
-
99
- df = pd.read_csv("medical_coding_train_1.csv")
100
-
101
- documents = [
102
- Document(
103
- page_content=f"note: {row['note']}\nicd_code: {row['icd_codes']}",
104
- metadata={"icd_code": row["icd_codes"]}
105
- )
106
- for _, row in df.iterrows()
107
- ]
108
-
109
-
110
- # 2. Chunk Documents
111
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
112
- docs = text_splitter.split_documents(documents)
113
-
114
- # 3. Create Embeddings and Vector Store (FAISS)
115
- embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
116
- db = FAISS.from_documents(docs, embeddings)
117
- retriever = db.as_retriever(search_kwargs={"k": 2})
118
-
119
-
120
- RAG_PROMPT_TEMPLATE = """
121
- You are an expert medical coder.
122
- Your task is to determine the most accurate ICD-10-CM code for the given clinical note.
123
-
124
- Use ONLY the following context (which may include ICD codes from similar cases).
125
- If you cannot determine a match from the provided context, respond exactly with:
126
- "I cannot find the code in the provided documents."
127
-
128
- Return ONLY the ICD-10-CM code itself — no explanation, no text, no punctuation.
129
-
130
- Context:
131
- {context}
132
-
133
- Clinical Note:
134
- {question}
135
-
136
- ICD-10-CM Code:
137
- """
138
-
139
-
140
- rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
141
- #direct_chain = LLMChain(llm=hf_llm, prompt=rag_prompt)
142
-
143
-
144
- # 5. Create the QA Chain
145
- qa_chain = RetrievalQA.from_chain_type(
146
- llm=hf_llm,
147
- chain_type="stuff",
148
- retriever=retriever,
149
- return_source_documents=False,
150
- chain_type_kwargs={"prompt": rag_prompt}
151
- )
152
-
153
- def extract_icd_code(text):
154
- # Pattern to match "ICD-10-CM Code:" followed by the code
155
- pattern = r'ICD-10-CM Code:\s*([A-Z0-9.]+)'
156
-
157
- match = re.search(pattern, text)
158
-
159
- if match:
160
- return match.group(1)
161
- return None
162
-
163
- def generate_code_rag(clinical_note, retriever, threshold=0.35):
164
- """Generates the ICD code using RAG."""
165
-
166
- # Format the user question for the RAG prompt template
167
- query = f"Code the following: {clinical_note}"
168
-
169
- # Step 1: Retrieve docs
170
- docs_and_scores = db.similarity_search_with_score(query, k=2)
171
-
172
- # Step 2: Filter by similarity threshold
173
- relevant_docs = [doc for doc, score in docs_and_scores if score > threshold]
174
-
175
-
176
-
177
- if relevant_docs:
178
- #print(qa_chain)
179
- result = qa_chain({"query": query})["result"]
180
-
181
- #answer = result['result']
182
- icd_code = extract_icd_code(result)
183
- #print(icd_code)
184
- if icd_code == None:
185
- print("I got here")
186
- result = generate_response(clinical_note)
187
- return result
188
- else:
189
- return icd_code
190
-
191
- # Step 4: Otherwise, use LLM directly (no context)
192
-
193
-
194
- # Create the Gradio Interface
195
- gr.Interface(
196
- fn=generate_code_rag,
197
- inputs=gr.Textbox(lines=5, label="Enter Clinical Note Here", placeholder="e.g., Patient presented with simple laceration of the left hand."),
198
- outputs=gr.Textbox(label="Predicted ICD-10 Code"),
199
- title="ClaimSwift Medical Coding",
200
- description="",
201
- examples=[
202
- ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
203
- ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
204
- ["Simple laceration of the left hand without foreign body."],
205
- ]
206
- ).launch(server_name="0.0.0.0", server_port=7860)
 
import os
import re

import gradio as gr
import pandas as pd
import torch
import transformers
from huggingface_hub import login
from langchain.chains import RetrievalQA
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain_community.document_loaders.csv_loader import CSVLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# The splitter is published as `langchain_text_splitters` (plural); the
# singular spelling is not an importable module. Fall back to the legacy
# in-tree location so older LangChain installs keep working.
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
17
+
18
+
# Hub repository of the fine-tuned checkpoint. Kept under its own name so
# that `model` only ever refers to the loaded network, never to a string.
MODEL_ID = "abnuel/MedGemma-4b-ICD"

# System message used when the model is queried directly (no retrieval).
SYSTEM_PROMPT = (
    "You are an expert medical coder. Your task is to analyze the clinical "
    "description provided and output only the single, most appropriate "
    "ICD-10-CM code. Do not include any text, justification other than the "
    "code itself."
)

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Reuse EOS for padding when the tokenizer ships without a pad token.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
)
40
+
41
+
def generate_response(clinical_note):
    """Predict an ICD-10-CM code for a clinical note without retrieval.

    Builds a system + user conversation from ``SYSTEM_PROMPT`` and the note,
    renders it with the tokenizer's chat template, and decodes greedily with
    the module-level ``model``.

    Args:
        clinical_note: Free-text clinical description to code.

    Returns:
        The newly generated text (prompt tokens excluded), with special
        tokens removed and surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Code the following: {clinical_note}"},
    ]

    # Render the conversation and move the tensors to the model's device.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Greedy decoding. `temperature` is a sampling-only setting, so it is
        # intentionally not passed together with do_sample=False.
        generated = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
        )

    # Keep only the tokens produced after the prompt.
    new_tokens = generated[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
75
+
76
+ # --- Example Usage ---
77
+ #test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."
78
+
79
+
80
+ #print(f"Clinical Note: {test_note}")
81
+ #response = generate_response(test_note)
82
+ #print(f"Generated ICD Codes: {response}")
83
+
84
+
85
+
# Text-generation pipeline backing the LangChain LLM used by the RAG chain.
# Decoding is greedy (do_sample=False); `temperature` only applies when
# sampling, so it is not passed here.
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=50,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
hf_llm = HuggingFacePipeline(pipeline=pipe)
97
+
98
+
# 1. Load the labelled notes; each CSV row becomes one Document whose text
#    carries both the note and its code, with the code repeated as metadata.
df = pd.read_csv("medical_coding_train_1.csv")

documents = [
    Document(
        page_content=f"note: {record['note']}\nicd_code: {record['icd_codes']}",
        metadata={"icd_code": record["icd_codes"]},
    )
    for _, record in df.iterrows()
]

# 2. Chunk Documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)

# 3. Create Embeddings and Vector Store (FAISS)
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.from_documents(docs, embeddings)

# Retriever handed to the QA chain: top two chunks per query.
retriever = db.as_retriever(search_kwargs={"k": 2})
118
+
119
+
# 4. Prompt for the retrieval-augmented path. It ends with the
#    "ICD-10-CM Code:" label that extract_icd_code() searches for.
RAG_PROMPT_TEMPLATE = """
You are an expert medical coder.
Your task is to determine the most accurate ICD-10-CM code for the given clinical note.

Use ONLY the following context (which may include ICD codes from similar cases).
If you cannot determine a match from the provided context, respond exactly with:
"I cannot find the code in the provided documents."

Return ONLY the ICD-10-CM code itself — no explanation, no text, no punctuation.

Context:
{context}

Clinical Note:
{question}

ICD-10-CM Code:
"""

rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

# 5. Create the QA Chain
qa_chain = RetrievalQA.from_chain_type(
    llm=hf_llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    # Use the custom prompt above in place of the chain's default one.
    chain_type_kwargs={"prompt": rag_prompt},
)
152
+
def extract_icd_code(text):
    """Return the ICD-10-CM code that follows an ``ICD-10-CM Code:`` label.

    The captured value must look like an ICD-10-CM code: a letter, a digit,
    an alphanumeric character, then optionally up to four more alphanumeric
    characters with or without the decimal point. Requiring that shape keeps
    the refusal sentence "I cannot find the code ..." from being read as the
    code ``I``.

    Args:
        text: Model output to search; may be empty or ``None``.

    Returns:
        The code string, or ``None`` when no label followed by a well-formed
        code is present.
    """
    if not text:
        return None

    pattern = (
        r'ICD-10-CM Code:\s*'
        r'([A-Z][0-9][0-9A-Z](?:\.?[0-9A-Z]{1,4})?)'
        r'(?![0-9A-Za-z])'  # the code must not run on into more letters/digits
    )
    match = re.search(pattern, text)
    return match.group(1) if match else None
162
+
def generate_code_rag(clinical_note, retriever=None, threshold=0.35):
    """Predict an ICD-10-CM code, preferring the RAG chain when usable.

    The note is looked up in the module-level FAISS store ``db``. If any hit
    passes the score filter, ``qa_chain`` is run and the code is parsed from
    its answer. In every other case (no hit passes, or no code can be parsed)
    the model is queried directly through ``generate_response``, so the
    function always returns a string.

    Args:
        clinical_note: Free-text clinical description entered by the user.
        retriever: Accepted for backward compatibility and not consulted;
            retrieval goes through ``db`` and ``qa_chain``. It is optional so
            the Gradio interface, which passes only the note, can call this
            function.
        threshold: Score cut-off applied to the retrieved documents.

    Returns:
        The predicted code (or raw model answer from the direct path); an
        empty string when the note is empty or whitespace-only.
    """
    if not clinical_note or not clinical_note.strip():
        return ""

    # Format the user question for the RAG prompt template
    query = f"Code the following: {clinical_note}"

    # Step 1: Retrieve docs
    docs_and_scores = db.similarity_search_with_score(query, k=2)

    # Step 2: Filter by similarity threshold
    # NOTE(review): the FAISS store reports a distance by default (lower means
    # closer), so `score > threshold` appears to keep the *less* similar hits.
    # The original comparison is preserved here; confirm the intended
    # direction before changing it.
    has_context = any(score > threshold for _, score in docs_and_scores)

    # Step 3: Try the retrieval-augmented chain
    if has_context:
        result = qa_chain({"query": query})["result"]
        icd_code = extract_icd_code(result)
        if icd_code is not None:
            return icd_code

    # Step 4: Otherwise, use LLM directly (no context)
    return generate_response(clinical_note)
192
+
193
+
# Create the Gradio Interface
note_input = gr.Textbox(
    lines=5,
    label="Enter Clinical Note Here",
    placeholder="e.g., Patient presented with simple laceration of the left hand.",
)
code_output = gr.Textbox(label="Predicted ICD-10 Code")

# One single-element list per example, matching the single input component.
example_notes = [
    ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
    ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
    ["Simple laceration of the left hand without foreign body."],
]

gr.Interface(
    fn=generate_code_rag,
    inputs=note_input,
    outputs=code_output,
    title="ClaimSwift Medical Coding",
    description="",
    examples=example_notes,
).launch(server_name="0.0.0.0", server_port=7860)