abnuel commited on
Commit
84bc192
·
verified ·
1 Parent(s): e0c3895

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +206 -0
  3. medical_coding_train_1.csv +3 -0
  4. requirements.txt +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ medical_coding_train_1.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
5
+ from huggingface_hub import login
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain.embeddings import SentenceTransformerEmbeddings
8
+ from langchain.vectorstores import FAISS
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.llms import HuggingFacePipeline
12
+ from langchain_community.document_loaders.csv_loader import CSVLoader
13
+ import transformers
14
+ from langchain.schema import Document
15
+ import gradio as gr
16
+ import re
17
+
18
+
19
+ model = "abnuel/MedGemma-4b-ICD"
20
+
21
+ #tokenizer = AutoTokenizer.from_pretrained("abnuel/MedGemma-4b-ICD")
22
+
23
+ SYSTEM_PROMPT = "You are an expert medical coder. Your task is to analyze the clinical description provided and output only the single, most appropriate ICD-10-CM code. Do not include any text, justification other than the code itself."
24
+
25
+ bnb_config = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_quant_type="nf4",
28
+ bnb_4bit_use_double_quant=True,
29
+ bnb_4bit_compute_dtype=torch.bfloat16
30
+ )
31
+ tokenizer = AutoTokenizer.from_pretrained(model)
32
+ if tokenizer.pad_token is None:
33
+ tokenizer.pad_token = tokenizer.eos_token
34
+
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ model,
37
+ quantization_config=bnb_config,
38
+ device_map="auto"
39
+ )
40
+
41
+
42
+ def generate_response(clinical_note):
43
+ messages = [
44
+ {"role": "system", "content": SYSTEM_PROMPT},
45
+ {"role": "user", "content": f"Code the following: {clinical_note}"},
46
+ ]
47
+
48
+ # 3. Apply chat template and tokenize
49
+
50
+ inputs = tokenizer.apply_chat_template(
51
+ messages,
52
+ add_generation_prompt=True,
53
+ tokenize=True,
54
+ return_dict=True,
55
+ return_tensors="pt"
56
+ ).to(model.device)
57
+
58
+ input_len = inputs["input_ids"].shape[-1]
59
+
60
+ # 4. Generate the response
61
+ with torch.inference_mode():
62
+ generation = model.generate(
63
+ **inputs,
64
+ max_new_tokens=200, # Max length of the generated ICD codes
65
+ do_sample=False, # Use greedy decoding for predictable output
66
+ temperature=0.0, # Zero temperature for deterministic results
67
+ )
68
+
69
+ # 5. Decode the output
70
+ # Extract only the newly generated tokens
71
+ generation = generation[0][input_len:]
72
+ decoded_output = tokenizer.decode(generation, skip_special_tokens=True)
73
+
74
+ return decoded_output.strip()
75
+
76
+ # --- Example Usage ---
77
+ #test_note = "Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."
78
+
79
+
80
+ #print(f"Clinical Note: {test_note}")
81
+ #response = generate_response(test_note)
82
+ #print(f"Generated ICD Codes: {response}")
83
+
84
+
85
+
86
+ pipe = transformers.pipeline(
87
+ "text-generation",
88
+ model=model,
89
+ tokenizer=tokenizer,
90
+ max_new_tokens=50,
91
+ temperature=0.1,
92
+ do_sample=False,
93
+ pad_token_id=tokenizer.eos_token_id,
94
+
95
+ )
96
+ hf_llm = HuggingFacePipeline(pipeline=pipe)
97
+
98
+
99
+ df = pd.read_csv("medical_coding_train_1.csv")
100
+
101
+ documents = [
102
+ Document(
103
+ page_content=f"note: {row['note']}\nicd_code: {row['icd_codes']}",
104
+ metadata={"icd_code": row["icd_codes"]}
105
+ )
106
+ for _, row in df.iterrows()
107
+ ]
108
+
109
+
110
+ # 2. Chunk Documents
111
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
112
+ docs = text_splitter.split_documents(documents)
113
+
114
+ # 3. Create Embeddings and Vector Store (FAISS)
115
+ embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
116
+ db = FAISS.from_documents(docs, embeddings)
117
+ retriever = db.as_retriever(search_kwargs={"k": 2})
118
+
119
+
120
+ RAG_PROMPT_TEMPLATE = """
121
+ You are an expert medical coder.
122
+ Your task is to determine the most accurate ICD-10-CM code for the given clinical note.
123
+
124
+ Use ONLY the following context (which may include ICD codes from similar cases).
125
+ If you cannot determine a match from the provided context, respond exactly with:
126
+ "I cannot find the code in the provided documents."
127
+
128
+ Return ONLY the ICD-10-CM code itself — no explanation, no text, no punctuation.
129
+
130
+ Context:
131
+ {context}
132
+
133
+ Clinical Note:
134
+ {question}
135
+
136
+ ICD-10-CM Code:
137
+ """
138
+
139
+
140
+ rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
141
+ #direct_chain = LLMChain(llm=hf_llm, prompt=rag_prompt)
142
+
143
+
144
+ # 5. Create the QA Chain
145
+ qa_chain = RetrievalQA.from_chain_type(
146
+ llm=hf_llm,
147
+ chain_type="stuff",
148
+ retriever=retriever,
149
+ return_source_documents=False,
150
+ chain_type_kwargs={"prompt": rag_prompt}
151
+ )
152
+
153
+ def extract_icd_code(text):
154
+ # Pattern to match "ICD-10-CM Code:" followed by the code
155
+ pattern = r'ICD-10-CM Code:\s*([A-Z0-9.]+)'
156
+
157
+ match = re.search(pattern, text)
158
+
159
+ if match:
160
+ return match.group(1)
161
+ return None
162
+
163
+ def generate_code_rag(clinical_note, retriever, threshold=0.35):
164
+ """Generates the ICD code using RAG."""
165
+
166
+ # Format the user question for the RAG prompt template
167
+ query = f"Code the following: {clinical_note}"
168
+
169
+ # Step 1: Retrieve docs
170
+ docs_and_scores = db.similarity_search_with_score(query, k=2)
171
+
172
+ # Step 2: Filter by similarity threshold
173
+ relevant_docs = [doc for doc, score in docs_and_scores if score > threshold]
174
+
175
+
176
+
177
+ if relevant_docs:
178
+ #print(qa_chain)
179
+ result = qa_chain({"query": query})["result"]
180
+
181
+ #answer = result['result']
182
+ icd_code = extract_icd_code(result)
183
+ #print(icd_code)
184
+ if icd_code == None:
185
+ print("I got here")
186
+ result = generate_response(clinical_note)
187
+ return result
188
+ else:
189
+ return icd_code
190
+
191
+ # Step 4: Otherwise, use LLM directly (no context)
192
+
193
+
194
+ # Create the Gradio Interface
195
+ gr.Interface(
196
+ fn=generate_code_rag,
197
+ inputs=gr.Textbox(lines=5, label="Enter Clinical Note Here", placeholder="e.g., Patient presented with simple laceration of the left hand."),
198
+ outputs=gr.Textbox(label="Predicted ICD-10 Code"),
199
+ title="ClaimSwift Medical Coding",
200
+ description="",
201
+ examples=[
202
+ ["Benign neoplasm of peripheral nerves and autonomic nervous system of face, head, and neck"],
203
+ ["Sudden onset chest pain and shortness of breath. Initial diagnosis points towards unstable angina."],
204
+ ["Simple laceration of the left hand without foreign body."],
205
+ ]
206
+ ).launch(server_name="0.0.0.0", server_port=7860)
medical_coding_train_1.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a812e716b9a718799315f93c2aa4ad7efe0c9b93141cdca94f110aa29653af9f
3
+ size 15109530
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas
2
+ langchain
3
+ langchain-community
4
+ langchain-core
5
+ transformers
6
+ torch
7
+ accelerate
8
+ bitsandbytes
9
+ sentence-transformers
10
+ faiss-cpu
11
+ gradio
12
+ peft