Ahmad-01 commited on
Commit
bc14f17
Β·
verified Β·
1 Parent(s): 37977a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -53
app.py CHANGED
@@ -1,8 +1,6 @@
1
  # app.py
2
- # Synthetic Patient Records RAG App
3
  # Author: Your Name
4
- # Description: A Retrieval-Augmented Generation (RAG) application for synthetic hospital datasets.
5
- # Runs on Hugging Face Spaces using Gradio + FAISS + Sentence Transformers.
6
 
7
  import os
8
  import pandas as pd
@@ -15,18 +13,15 @@ from transformers import pipeline
15
  # ======================================================
16
  # 1. Dataset Handling
17
  # ======================================================
18
-
19
  DEFAULT_DATA_PATH = "patients.csv"
20
 
21
  def safe_load_csv(path):
22
- """Safely load the dataset from CSV"""
23
  if not os.path.exists(path):
24
- raise FileNotFoundError(f"❌ No dataset found at {path}. Please upload 'patients.csv'.")
25
  df = pd.read_csv(path)
26
  return df
27
 
28
  def preprocess_df(df):
29
- """Cleans and harmonizes column names and fields"""
30
  df = df.copy()
31
  ren = {}
32
  for c in df.columns:
@@ -61,28 +56,16 @@ def preprocess_df(df):
61
  if col in df.columns:
62
  df[col] = pd.to_datetime(df[col], errors="coerce")
63
 
64
- if "Length_of_Stay" not in df.columns or df["Length_of_Stay"].isnull().all():
65
- if "Admission_Date" in df.columns and "Discharge_Date" in df.columns:
66
- df["Length_of_Stay"] = (
67
- (df["Discharge_Date"] - df["Admission_Date"]).dt.days.fillna(0).astype(int)
68
- )
69
- else:
70
- df["Length_of_Stay"] = pd.Series([1] * len(df), index=df.index)
71
 
72
- diag_series = (
73
- df["Diagnosis"].fillna("").astype(str) if "Diagnosis" in df.columns else pd.Series([""] * len(df))
74
- )
75
- treat_series = (
76
- df["Treatment"].fillna("").astype(str) if "Treatment" in df.columns else pd.Series([""] * len(df))
77
- )
78
  if "Notes" not in df.columns:
79
  df["Notes"] = (diag_series + " " + treat_series).str.strip()
80
 
81
  df["Notes"] = df["Notes"].astype(str)
82
- df["Satisfaction_Score"] = pd.to_numeric(
83
- df.get("Satisfaction_Score", pd.Series(np.nan, index=df.index)), errors="coerce"
84
- ).fillna(-1)
85
-
86
  if "Patient_ID" not in df.columns:
87
  df.insert(0, "Patient_ID", range(1, len(df) + 1))
88
  return df.reset_index(drop=True)
@@ -91,10 +74,8 @@ def preprocess_df(df):
91
  # ======================================================
92
  # 2. Embedding + FAISS Setup
93
  # ======================================================
94
-
95
  def build_faiss_index(df, embed_model):
96
- """Build FAISS index from Notes column"""
97
- embeddings = embed_model.encode(df["Notes"].tolist(), convert_to_numpy=True, show_progress_bar=True)
98
  index = faiss.IndexFlatL2(embeddings.shape[1])
99
  index.add(embeddings)
100
  return index, embeddings
@@ -103,16 +84,14 @@ def build_faiss_index(df, embed_model):
103
  # ======================================================
104
  # 3. RAG Query Function
105
  # ======================================================
106
-
107
  def generate_answer(query, df, embed_model, index, generator, top_k=3):
108
- """Retrieve relevant notes and generate LLM summary"""
109
  query_emb = embed_model.encode([query])
110
  _, idxs = index.search(np.array(query_emb).astype("float32"), top_k)
111
  retrieved = df.iloc[idxs[0]]
112
  context = "\n".join(retrieved["Notes"].astype(str).tolist())
113
 
114
  prompt = f"""
115
- You are a hospital data assistant. Use the following patient notes to answer the user's question clearly and concisely.
116
 
117
  Context:
118
  {context}
@@ -121,72 +100,59 @@ Question: {query}
121
 
122
  Answer:
123
  """
124
- result = generator(prompt, max_new_tokens=200, do_sample=True, temperature=0.7)[0]["generated_text"]
125
  return result, retrieved[["Patient_ID", "Department", "Satisfaction_Score", "Length_of_Stay", "Notes"]]
126
 
127
 
128
  # ======================================================
129
- # 4. Gradio Interface Logic
130
  # ======================================================
131
-
132
  def create_interface():
133
- # Load default dataset
134
  df_raw = safe_load_csv(DEFAULT_DATA_PATH)
135
  df = preprocess_df(df_raw)
136
 
137
- # Embeddings + Index
138
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
139
  index, _ = build_faiss_index(df, embed_model)
140
 
141
- # LLM generator
142
- generator = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", device_map="auto")
143
 
144
  def query_app(user_query, uploaded_file=None):
145
- """Handles user queries and optional dataset uploads"""
146
  local_df = df.copy()
147
  local_index = index
148
 
149
- # If user uploads a dataset
150
  if uploaded_file is not None:
151
  try:
152
  user_df = preprocess_df(pd.read_csv(uploaded_file.name))
153
  local_df = user_df
154
  local_index, _ = build_faiss_index(local_df, embed_model)
155
  except Exception as e:
156
- return f"⚠️ Failed to process uploaded file: {str(e)}", pd.DataFrame()
157
 
158
- # Run RAG pipeline
159
  answer, retrieved = generate_answer(user_query, local_df, embed_model, local_index, generator)
160
  return answer, retrieved
161
 
162
- # Gradio UI
163
  iface = gr.Interface(
164
  fn=query_app,
165
  inputs=[
166
  gr.Textbox(label="πŸ’¬ Ask a question about patient data"),
167
- gr.File(label="πŸ“‚ Upload a custom patient CSV (optional)")
168
  ],
169
  outputs=[
170
  gr.Textbox(label="πŸ€– AI Generated Answer"),
171
  gr.Dataframe(label="πŸ“‹ Retrieved Records")
172
  ],
173
  title="πŸ₯ Synthetic Patient Records RAG App",
174
- description=(
175
- "Upload your patient dataset (or use the default one) and ask natural-language questions.\n"
176
- "Built with Sentence Transformers + FAISS + Mistral 7B."
177
- ),
178
  examples=[
179
  ["Summarize satisfaction trends by department."],
180
- ["Find patients over 65 with stays longer than 10 days."],
181
- ["Generate a synthetic patient summary for a cardiology admission."],
182
- ],
183
  )
184
  return iface
185
 
186
 
187
- # ======================================================
188
- # 5. Run App
189
- # ======================================================
190
  if __name__ == "__main__":
191
  app = create_interface()
192
  app.launch()
 
1
  # app.py
2
+ # Lightweight RAG App for Hugging Face Spaces (CPU-friendly)
3
  # Author: Your Name
 
 
4
 
5
  import os
6
  import pandas as pd
 
13
  # ======================================================
14
  # 1. Dataset Handling
15
  # ======================================================
 
16
  DEFAULT_DATA_PATH = "patients.csv"
17
 
18
  def safe_load_csv(path):
 
19
  if not os.path.exists(path):
20
+ raise FileNotFoundError(f"No dataset found at {path}. Please upload 'patients.csv'.")
21
  df = pd.read_csv(path)
22
  return df
23
 
24
  def preprocess_df(df):
 
25
  df = df.copy()
26
  ren = {}
27
  for c in df.columns:
 
56
  if col in df.columns:
57
  df[col] = pd.to_datetime(df[col], errors="coerce")
58
 
59
+ if "Length_of_Stay" not in df.columns:
60
+ df["Length_of_Stay"] = 1
61
+
62
+ diag_series = df["Diagnosis"].fillna("").astype(str) if "Diagnosis" in df.columns else pd.Series([""] * len(df))
63
+ treat_series = df["Treatment"].fillna("").astype(str) if "Treatment" in df.columns else pd.Series([""] * len(df))
 
 
64
 
 
 
 
 
 
 
65
  if "Notes" not in df.columns:
66
  df["Notes"] = (diag_series + " " + treat_series).str.strip()
67
 
68
  df["Notes"] = df["Notes"].astype(str)
 
 
 
 
69
  if "Patient_ID" not in df.columns:
70
  df.insert(0, "Patient_ID", range(1, len(df) + 1))
71
  return df.reset_index(drop=True)
 
74
  # ======================================================
75
  # 2. Embedding + FAISS Setup
76
  # ======================================================
 
77
  def build_faiss_index(df, embed_model):
78
+ embeddings = embed_model.encode(df["Notes"].tolist(), convert_to_numpy=True, show_progress_bar=False)
 
79
  index = faiss.IndexFlatL2(embeddings.shape[1])
80
  index.add(embeddings)
81
  return index, embeddings
 
84
  # ======================================================
85
  # 3. RAG Query Function
86
  # ======================================================
 
87
  def generate_answer(query, df, embed_model, index, generator, top_k=3):
 
88
  query_emb = embed_model.encode([query])
89
  _, idxs = index.search(np.array(query_emb).astype("float32"), top_k)
90
  retrieved = df.iloc[idxs[0]]
91
  context = "\n".join(retrieved["Notes"].astype(str).tolist())
92
 
93
  prompt = f"""
94
+ You are a hospital assistant. Use the following context to answer the question.
95
 
96
  Context:
97
  {context}
 
100
 
101
  Answer:
102
  """
103
+ result = generator(prompt, max_new_tokens=200)[0]["generated_text"]
104
  return result, retrieved[["Patient_ID", "Department", "Satisfaction_Score", "Length_of_Stay", "Notes"]]
105
 
106
 
107
  # ======================================================
108
+ # 4. Gradio Interface
109
  # ======================================================
 
110
  def create_interface():
 
111
  df_raw = safe_load_csv(DEFAULT_DATA_PATH)
112
  df = preprocess_df(df_raw)
113
 
 
114
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
115
  index, _ = build_faiss_index(df, embed_model)
116
 
117
+ # βœ… Use lightweight model (works on CPU)
118
+ generator = pipeline("text2text-generation", model="google/flan-t5-base")
119
 
120
  def query_app(user_query, uploaded_file=None):
 
121
  local_df = df.copy()
122
  local_index = index
123
 
 
124
  if uploaded_file is not None:
125
  try:
126
  user_df = preprocess_df(pd.read_csv(uploaded_file.name))
127
  local_df = user_df
128
  local_index, _ = build_faiss_index(local_df, embed_model)
129
  except Exception as e:
130
+ return f"⚠️ Error loading file: {e}", pd.DataFrame()
131
 
 
132
  answer, retrieved = generate_answer(user_query, local_df, embed_model, local_index, generator)
133
  return answer, retrieved
134
 
 
135
  iface = gr.Interface(
136
  fn=query_app,
137
  inputs=[
138
  gr.Textbox(label="πŸ’¬ Ask a question about patient data"),
139
+ gr.File(label="πŸ“‚ Upload a patient CSV (optional)")
140
  ],
141
  outputs=[
142
  gr.Textbox(label="πŸ€– AI Generated Answer"),
143
  gr.Dataframe(label="πŸ“‹ Retrieved Records")
144
  ],
145
  title="πŸ₯ Synthetic Patient Records RAG App",
146
+ description="Ask natural-language questions about synthetic hospital data. Powered by Sentence Transformers + Flan-T5.",
 
 
 
147
  examples=[
148
  ["Summarize satisfaction trends by department."],
149
+ ["Find patients older than 65 with long hospital stays."],
150
+ ["Generate a summary of cardiology patients."]
151
+ ]
152
  )
153
  return iface
154
 
155
 
 
 
 
156
  if __name__ == "__main__":
157
  app = create_interface()
158
  app.launch()