Commit 902fa1b · Parent(s): first

Files changed:
- Dockerfile +18 -0
- __pycache__/api.cpython-311.pyc +0 -0
- __pycache__/model.cpython-311.pyc +0 -0
- data/patient_data.csv +101 -0
- data/processed_patient_data.csv +101 -0
- docker_entrypoint.sh +10 -0
- models/best_vae_model.pth +0 -0
- models/encoders.pkl +0 -0
- models/feature_names.pkl +0 -0
- models/scaler.pkl +0 -0
- models/vae_model.pth +0 -0
- readme.md +431 -0
- requirements.txt +7 -0
- src/__pycache__/api.cpython-311.pyc +0 -0
- src/__pycache__/model.cpython-311.pyc +0 -0
- src/api.py +290 -0
- src/continual_train.py +27 -0
- src/continual_train_loop.py +29 -0
- src/data_preprocessing.py +122 -0
- src/model.py +57 -0
- src/train.py +158 -0
- src/web_scraper.py +58 -0
- tests/test_api.py +27 -0
Dockerfile
ADDED
@@ -0,0 +1,18 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+COPY . .
+
+RUN pip install --upgrade pip && \
+    pip install -r requirements.txt && \
+    pip install beautifulsoup4 lxml
+
+RUN mkdir -p data models
+
+EXPOSE 8000
+
+COPY docker_entrypoint.sh /app/docker_entrypoint.sh
+RUN chmod +x /app/docker_entrypoint.sh
+
+CMD ["/app/docker_entrypoint.sh"]
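For reference, a typical way to build and run this image locally; the image tag `healthcare-vae` is illustrative and not part of the commit:

```bash
# Build the image from the repository root
docker build -t healthcare-vae .

# Run it in the foreground, publishing the port declared by EXPOSE 8000
docker run --rm -p 8000:8000 healthcare-vae
```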
__pycache__/api.cpython-311.pyc
ADDED
Binary file (12.7 kB)

__pycache__/model.cpython-311.pyc
ADDED
Binary file (4.22 kB)
data/patient_data.csv
ADDED
@@ -0,0 +1,101 @@
+PatientID,FirstName,LastName,Gender,Age,DateOfBirth,Diagnosis,BloodType,AdmissionDate,DischargeDate
+1001,Brittney,Davies,Female,32,1992-10-24,Migraine,B+,2024-06-19,2025-05-10
+1002,Aaron,Johnson,Male,93,1931-08-09,Diabetes,O+,2025-02-09,2025-03-05
+1003,Nathan,Romero,Male,73,1952-03-22,Asthma,O+,2025-03-15,2025-05-25
+1004,Corey,Garcia,Male,84,1940-08-21,Pneumonia,AB-,2024-11-09,2024-12-03
+1005,Brandon,Scott,Male,42,1983-02-20,Diabetes,A-,2024-12-17,2025-02-12
+1006,Patricia,Bernard,Female,82,1942-08-13,Arthritis,AB+,2025-01-21,2025-05-26
+1007,Patrick,Sandoval,Male,92,1932-06-26,Heart Disease,AB+,2025-01-09,2025-04-09
+1008,Heather,Hughes,Female,24,2000-10-18,Heart Disease,O+,2025-01-31,2025-03-10
+1009,Gina,Kline,Female,35,1989-10-05,Arthritis,AB-,2024-11-22,2025-01-20
+1010,Cory,Turner,Male,44,1981-01-07,Cancer,B-,2024-10-04,2024-12-07
+1011,Mary,Ray,Female,44,1981-04-13,Fracture,A-,2024-08-26,2024-11-12
+1012,Jennifer,Young,Female,80,1944-11-05,Arthritis,B+,2025-06-04,2025-06-25
+1013,Patricia,Johnson,Female,73,1951-06-19,Fracture,O-,2024-09-03,2024-09-08
+1014,Matthew,Davis,Male,8,2016-09-17,Arthritis,B+,2025-01-25,2025-05-09
+1015,Jason,Dodson,Male,62,1962-08-31,COVID-19,A-,2025-03-06,2025-04-27
+1016,Thomas,Baker,Male,97,1927-07-05,COVID-19,B+,2025-01-09,2025-03-29
+1017,Leonard,Cochran,Male,82,1943-05-06,Asthma,AB+,2025-02-28,2025-04-04
+1018,Jessica,Pearson,Female,62,1963-01-01,Heart Disease,B-,2024-07-05,2025-06-11
+1019,Robin,Johnson,Female,65,1960-05-03,Heart Disease,O-,2025-05-13,2025-05-27
+1020,Gary,Hill,Male,87,1938-05-31,Fracture,B+,2024-08-27,2025-03-04
+1021,Kevin,Lee,Male,81,1944-06-06,Pneumonia,O+,2024-10-25,2025-01-09
+1022,Robert,Powell,Male,68,1956-07-04,Migraine,A-,2025-03-11,2025-03-12
+1023,Danielle,Wright,Female,42,1983-01-18,Heart Disease,B+,2025-01-20,2025-06-16
+1024,Hannah,Fields,Female,23,2001-07-25,Fracture,AB-,2025-04-30,2025-06-19
+1025,Adam,Barton,Male,54,1970-08-11,Migraine,B-,2024-08-15,2024-10-19
+1026,Paula,Cochran,Female,83,1942-05-14,Diabetes,B+,2024-07-22,2025-02-12
+1027,Lisa,Christensen,Female,50,1974-10-15,COVID-19,B-,2025-02-24,2025-05-07
+1028,Joseph,Huff,Male,81,1944-01-04,Fracture,AB+,2025-03-31,2025-06-06
+1029,Steven,Spears,Male,99,1926-06-15,COVID-19,A-,2024-11-21,2025-01-04
+1030,Jennifer,Suarez,Female,95,1929-09-11,Asthma,AB+,2024-10-16,2025-06-10
+1031,Benjamin,Ross,Male,86,1938-07-22,Migraine,AB-,2024-10-01,2025-03-03
+1032,Michael,Cunningham,Male,70,1954-08-31,Pneumonia,B-,2025-03-09,2025-04-18
+1033,Tammy,Bullock,Female,39,1986-04-08,Pneumonia,O+,2025-03-16,2025-05-12
+1034,Nicolas,Harrison,Male,73,1951-11-07,Diabetes,AB+,2024-08-27,2024-09-21
+1035,Richard,Cortez,Male,44,1980-11-20,Fracture,B+,2025-02-15,2025-04-05
+1036,Jordan,Hernandez,Female,74,1950-07-01,Pneumonia,B-,2024-09-30,2025-05-06
+1037,Ralph,Newman,Male,14,2010-08-08,Hypertension,B+,2024-09-25,2025-01-31
+1038,Renee,Morrison,Female,63,1961-11-14,Pneumonia,A+,2025-03-08,2025-04-13
+1039,Karen,Clark,Female,91,1933-08-10,Heart Disease,O-,2024-09-30,2025-05-03
+1040,Lisa,Benson,Female,36,1989-04-04,Asthma,A-,2024-07-19,2025-01-13
+1041,Randall,May,Male,52,1973-05-06,Pneumonia,AB-,2025-01-13,2025-02-16
+1042,Michael,Baker,Male,14,2010-08-28,Fracture,B-,2025-02-23,2025-03-12
+1043,Todd,Harrington,Male,21,2003-10-26,COVID-19,O-,2025-01-18,2025-02-18
+1044,Terry,Bryan,Male,44,1980-06-22,COVID-19,A+,2025-05-07,2025-05-31
+1045,Sharon,Smith,Female,83,1942-01-01,Pneumonia,A+,2025-04-01,2025-05-12
+1046,David,Ray,Male,70,1954-11-26,Diabetes,A-,2024-10-09,2025-04-09
+1047,Danielle,Dominguez,Female,66,1959-05-25,COVID-19,AB-,2024-09-18,2024-10-08
+1048,Victoria,Johnson,Female,5,2019-10-05,Hypertension,A+,2025-04-02,2025-06-08
+1049,Robert,Ferguson,Male,68,1956-06-29,Hypertension,B+,2024-08-19,2025-01-31
+1050,Brandon,Hall,Male,71,1954-03-05,COVID-19,B-,2024-11-22,2025-05-02
+1051,Brittany,Bailey,Female,67,1958-01-24,Asthma,A-,2025-02-04,2025-03-07
+1052,Kyle,Ryan,Male,49,1976-05-07,Migraine,AB-,2024-09-13,2024-11-13
+1053,Andrew,Smith,Male,91,1933-07-29,Asthma,B-,2024-12-12,2025-02-11
+1054,Richard,Williams,Male,4,2021-02-14,Pneumonia,O-,2024-09-23,2024-09-30
+1055,Aaron,Walton,Male,58,1967-04-21,Cancer,B-,2024-09-23,2025-05-31
+1056,Stephanie,Johnson,Female,44,1980-12-07,Asthma,O+,2024-08-08,2025-01-11
+1057,Benjamin,Mitchell,Male,56,1968-09-19,Heart Disease,O+,2025-01-14,2025-02-14
+1058,Jeffrey,Spence,Male,6,2018-09-29,Migraine,B-,2024-08-23,2025-02-01
+1059,Jamie,Russell,Female,32,1992-09-08,Cancer,A+,2025-04-04,2025-06-26
+1060,Alexander,Hernandez,Male,37,1987-08-04,Heart Disease,B+,2024-11-18,2025-04-27
+1061,Nicole,Gibson,Female,54,1970-11-10,Fracture,A+,2025-05-28,2025-06-13
+1062,Juan,Thompson,Male,11,2014-03-18,Migraine,AB+,2024-12-30,2025-06-10
+1063,Duane,West,Male,27,1998-02-14,Diabetes,A-,2024-09-22,2024-10-05
+1064,Natalie,Lee,Female,37,1988-05-04,Hypertension,B+,2025-02-17,2025-05-02
+1065,James,Liu,Male,37,1987-12-28,Heart Disease,B-,2024-06-27,2024-11-01
+1066,Pam,Baker,Female,40,1985-05-23,Fracture,B+,2025-02-17,2025-03-27
+1067,David,Williams,Male,27,1997-09-16,Fracture,AB-,2025-02-12,2025-04-16
+1068,Mark,Gray,Male,66,1959-01-27,COVID-19,B+,2025-01-26,2025-06-22
+1069,Jessica,Cannon,Female,98,1926-11-23,Migraine,A-,2024-09-07,2025-05-21
+1070,Jason,Roberts,Male,28,1997-04-02,Arthritis,AB-,2024-12-28,2025-05-18
+1071,Robert,Walker,Male,97,1928-03-08,Pneumonia,A-,2025-05-19,2025-06-12
+1072,Crystal,Williams,Female,22,2002-10-17,Hypertension,O-,2024-09-06,2024-10-25
+1073,Emily,Thomas,Female,10,2015-05-16,COVID-19,AB-,2024-08-31,2025-03-02
+1074,Mark,White,Male,46,1979-01-14,Diabetes,AB+,2025-03-30,2025-06-01
+1075,Jillian,Lucas,Female,47,1977-10-13,Diabetes,A+,2024-10-06,2024-12-23
+1076,Jeffrey,Nguyen,Male,23,2001-09-06,Arthritis,B+,2025-01-24,2025-04-30
+1077,Wendy,Nguyen,Female,13,2012-01-20,Cancer,O-,2024-09-18,2025-01-07
+1078,Lance,Miller,Male,47,1977-09-10,Fracture,AB+,2024-10-12,2025-05-20
+1079,William,Andrews,Male,69,1955-11-03,COVID-19,O+,2024-06-23,2024-12-19
+1080,Heather,King,Female,40,1984-10-13,Heart Disease,B-,2025-04-27,2025-05-19
+1081,Kayla,Fields,Female,25,2000-05-02,Arthritis,AB-,2025-01-04,2025-05-02
+1082,Sarah,Kline,Female,62,1962-07-04,Hypertension,B+,2024-07-20,2024-09-17
+1083,Jeremy,Miller,Male,24,2001-06-02,COVID-19,B-,2025-04-08,2025-06-13
+1084,Melissa,Wallace,Female,70,1955-03-01,Migraine,O+,2024-12-27,2025-05-31
+1085,John,Weiss,Male,67,1958-03-03,Pneumonia,AB-,2025-05-19,2025-05-26
+1086,Darren,Herrera,Male,11,2014-04-08,Asthma,A-,2024-07-04,2024-09-18
+1087,Megan,Kelly,Female,9,2016-03-14,Asthma,A-,2024-06-24,2024-12-10
+1088,John,Robertson,Male,8,2017-04-20,Heart Disease,O-,2024-08-18,2025-04-24
+1089,Michelle,Robles,Female,19,2006-06-09,Diabetes,O-,2024-11-29,2025-06-17
+1090,Lisa,Wright,Female,12,2012-07-14,Cancer,B-,2025-01-21,2025-03-05
+1091,Melissa,Reynolds,Female,12,2013-03-18,Hypertension,O-,2024-09-29,2024-10-24
+1092,Kristen,Sanders,Female,48,1976-10-23,Heart Disease,AB+,2025-05-22,2025-06-22
+1093,Timothy,Short,Male,2,2022-08-08,Heart Disease,B+,2024-12-07,2025-01-12
+1094,Gene,Greene,Male,91,1933-11-12,Diabetes,AB+,2024-09-11,2024-11-19
+1095,Jennifer,Brown,Female,2,2023-01-22,Diabetes,A+,2024-07-09,2025-06-17
+1096,Jeremy,Wright,Male,45,1980-03-29,Cancer,O+,2025-01-25,2025-02-07
+1097,Patricia,Pierce,Female,58,1967-01-20,COVID-19,A+,2024-12-27,2025-06-16
+1098,Cindy,Brooks,Female,40,1984-12-06,Asthma,A+,2024-07-19,2025-04-08
+1099,Seth,Dawson,Male,30,1995-03-05,Diabetes,AB-,2024-12-20,2025-03-30
+1100,Gene,Kelly,Male,70,1954-08-22,Arthritis,AB-,2024-09-05,2025-02-02
data/processed_patient_data.csv
ADDED
@@ -0,0 +1,101 @@
+age,gender,diagnosis,blood_type,length_of_stay,age_group,admission_season,admission_day,admission_month,admission_year
+32,0,8,4,325,1,1,2,5,4
+93,1,4,6,24,4,0,6,1,5
+73,1,1,6,71,4,0,5,2,5
+84,1,9,3,24,4,3,5,10,4
+42,1,4,1,57,2,3,1,11,4
+82,0,0,2,125,4,0,1,0,5
+92,1,6,2,90,4,0,3,0,5
+24,0,6,6,38,1,0,4,0,5
+35,0,0,3,59,1,3,4,10,4
+44,1,3,5,64,2,3,4,9,4
+44,0,5,1,78,2,2,0,7,4
+80,0,0,4,21,4,1,2,5,5
+73,0,5,7,5,4,2,1,8,4
+8,1,0,4,104,0,0,5,0,5
+62,1,2,1,52,3,0,3,2,5
+97,1,2,4,79,4,0,3,0,5
+82,1,1,2,35,4,0,4,1,5
+62,0,6,5,341,3,2,4,6,4
+65,0,6,7,14,3,1,1,4,5
+87,1,5,4,189,4,2,1,7,4
+81,1,9,6,76,4,3,4,9,4
+68,1,8,1,1,4,0,1,2,5
+42,0,6,4,147,2,0,0,0,5
+23,0,5,3,50,1,1,2,3,5
+54,1,8,5,65,3,2,3,7,4
+83,0,4,4,205,4,2,0,6,4
+50,0,2,5,72,2,0,0,1,5
+81,1,5,2,67,4,0,0,2,5
+99,1,2,1,44,4,3,3,10,4
+95,0,1,2,237,4,3,2,9,4
+86,1,8,3,153,4,3,1,9,4
+70,1,9,5,40,4,0,6,2,5
+39,0,9,6,57,2,0,6,2,5
+73,1,4,2,25,4,2,1,7,4
+44,1,5,4,49,2,0,5,1,5
+74,0,9,5,218,4,2,0,8,4
+14,1,7,4,128,0,2,2,8,4
+63,0,9,0,36,3,0,5,2,5
+91,0,6,7,215,4,2,0,8,4
+36,0,1,1,178,2,2,4,6,4
+52,1,9,3,34,3,0,0,0,5
+14,1,5,5,17,0,0,6,1,5
+21,1,2,7,31,1,0,5,0,5
+44,1,2,0,24,2,1,2,4,5
+83,0,9,0,41,4,1,1,3,5
+70,1,4,1,182,4,3,2,9,4
+66,0,2,3,20,4,2,2,8,4
+5,0,7,0,67,0,1,2,3,5
+68,1,7,4,165,4,2,0,7,4
+71,1,2,5,161,4,3,4,10,4
+67,0,1,1,31,4,0,1,1,5
+49,1,8,3,61,2,2,4,8,4
+91,1,1,5,61,4,3,3,11,4
+4,1,9,7,7,0,2,0,8,4
+58,1,3,5,250,3,2,0,8,4
+44,0,1,6,156,2,2,3,7,4
+56,1,6,6,31,3,0,1,0,5
+6,1,8,5,162,0,2,4,7,4
+32,0,3,0,83,1,1,4,3,5
+37,1,6,4,160,2,3,0,10,4
+54,0,5,0,16,3,1,2,4,5
+11,1,8,2,162,0,3,0,11,4
+27,1,4,1,13,1,2,6,8,4
+37,0,7,4,74,2,0,0,1,5
+37,1,6,5,127,2,1,3,5,4
+40,0,5,4,38,2,0,0,1,5
+27,1,5,3,63,1,0,2,1,5
+66,1,2,4,147,4,0,6,0,5
+98,0,8,1,256,4,2,5,8,4
+28,1,0,3,141,1,3,5,11,4
+97,1,9,1,24,4,1,0,4,5
+22,0,7,7,49,1,2,4,8,4
+10,0,2,3,183,0,2,5,7,4
+46,1,4,2,63,2,0,6,2,5
+47,0,4,0,78,2,3,6,9,4
+23,1,0,4,96,1,0,4,0,5
+13,0,3,7,111,0,2,2,8,4
+47,1,5,2,220,2,3,5,9,4
+69,1,2,6,179,4,1,6,5,4
+40,0,6,5,22,2,1,6,3,5
+25,0,0,3,118,1,0,5,0,5
+62,0,7,4,59,3,2,5,6,4
+24,1,2,5,66,1,1,1,3,5
+70,0,8,6,155,4,3,4,11,4
+67,1,9,3,7,4,1,0,4,5
+11,1,1,1,76,0,2,3,6,4
+9,0,1,1,169,0,1,0,5,4
+8,1,6,7,249,0,2,6,7,4
+19,0,4,7,200,1,3,4,10,4
+12,0,3,5,43,0,0,1,0,5
+12,0,7,7,25,0,2,6,8,4
+48,0,6,2,31,2,1,3,4,5
+2,1,6,4,36,0,3,5,11,4
+91,1,4,2,69,4,2,2,8,4
+2,0,4,0,343,0,2,1,6,4
+45,1,3,6,13,2,0,5,0,5
+58,0,2,0,171,3,3,4,11,4
+40,0,1,0,263,2,2,4,6,4
+30,1,4,3,100,1,3,4,11,4
+70,1,0,3,150,4,2,3,8,4
docker_entrypoint.sh
ADDED
@@ -0,0 +1,10 @@
+#!/bin/sh
+
+# Start continual training loop in background
+python src/continual_train_loop.py &
+
+# Start web scraper in background
+python src/web_scraper.py &
+
+# Start FastAPI server (foreground)
+uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
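Since everything except uvicorn runs in the background, a quick way to confirm the container came up correctly (the container name is illustrative, and the image tag reuses the one from the build example above):

```bash
# Start the container detached under a name we can reference
docker run -d --name vae-space -p 8000:8000 healthcare-vae

# uvicorn runs in the foreground of the entrypoint, so its startup lines
# (plus any prints from the training loop and scraper) appear in the logs
docker logs vae-space

# The API answers on /health once the model and scaler have loaded
curl "http://localhost:8000/health"
```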
models/best_vae_model.pth
ADDED
Binary file (13.7 kB)

models/encoders.pkl
ADDED
Binary file (1.1 kB)

models/feature_names.pkl
ADDED
Binary file (155 Bytes)

models/scaler.pkl
ADDED
Binary file (855 Bytes)

models/vae_model.pth
ADDED
Binary file (13.6 kB)
readme.md
ADDED
@@ -0,0 +1,431 @@
+---
+title: Healthcare Synthetic Data VAE
+emoji: "🏥"
+colorFrom: blue
+colorTo: green
+sdk: gradio
+sdk_version: "3.41.2"
+app_file: src/api.py
+pinned: false
+---
+# Healthcare Synthetic Data Generation using VAE
+
+A complete pipeline for generating synthetic healthcare data using Variational Autoencoders (VAE) with FastAPI serving capabilities.
+
+## 🏥 Project Overview
+
+This project implements a **Variational Autoencoder (VAE)** to generate synthetic patient data for healthcare AI applications. The system can create realistic patient records while preserving privacy and statistical properties of the original data.
+
+### Key Features
+- 🔬 **Medical Data Generation**: Creates synthetic patient records with realistic correlations
+- 🔒 **Privacy-Preserving**: No direct storage of original patient data
+- 🚀 **Production Ready**: FastAPI deployment with RESTful endpoints
+- 📊 **Quality Validation**: Built-in data quality metrics and evaluation
+- 🎛️ **Configurable**: Easy hyperparameter tuning and model customization
+
+## 🏗️ Project Structure
+
+```
+healthcare-vae/
+├── README.md                         # This file
+├── requirements.txt                  # Python dependencies
+├── data/
+│   ├── raw_patient_data.csv          # Original patient data (not included)
+│   └── processed_patient_data.csv    # Preprocessed features
+├── src/
+│   ├── model.py                      # VAE architecture
+│   ├── data_preprocessing.py         # Data cleaning and feature engineering
+│   ├── train.py                      # Model training script
+│   ├── evaluate.py                   # Model evaluation and metrics
+│   └── api.py                        # FastAPI serving application
+├── models/
+│   ├── vae_model.pth                 # Trained VAE weights
+│   ├── best_vae_model.pth            # Best model checkpoint
+│   ├── scaler.pkl                    # Data preprocessing scaler
+│   └── feature_names.pkl             # Feature column names
+├── notebooks/
+│   ├── data_exploration.ipynb        # Data analysis and visualization
+│   └── model_analysis.ipynb          # Model performance analysis
+└── tests/
+    ├── test_model.py                 # Model unit tests
+    └── test_api.py                   # API endpoint tests
+```
+
+## 🧠 How VAE Works for Healthcare Data
+
+### Mathematical Foundation
+
+**Variational Autoencoder (VAE)** learns a compressed representation of patient data:
+
+```
+Patient Data → Encoder → Latent Space (μ, σ) → Decoder → Synthetic Patient
+```
+
+**Key Components:**
+1. **Encoder**: Maps patient features to latent space parameters
+2. **Latent Space**: Continuous representation of patient "types"
+3. **Decoder**: Generates new patients from latent codes
+4. **Loss Function**: Reconstruction + KL Divergence
+
+### Training Process
+
+```mermaid
+graph TD
+    A[Raw Patient Data] --> B[Data Preprocessing]
+    B --> C[Feature Engineering]
+    C --> D[Train/Validation Split]
+    D --> E[VAE Training]
+    E --> F[Model Validation]
+    F --> G{Good Performance?}
+    G -->|No| H[Adjust Hyperparameters]
+    H --> E
+    G -->|Yes| I[Save Best Model]
+    I --> J[Deploy API]
+```
+
+## 🚀 Quick Start
+
+### 1. Installation
+
+```bash
+# Clone repository
+git clone https://github.com/theaniketgiri/healthcare-vae.git
+cd healthcare-vae
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Create necessary directories
+mkdir -p data models notebooks tests
+```
+
+### 2. Data Preparation
+
+```python
+# Create sample patient data (for demonstration)
+python -c "
+import pandas as pd
+import numpy as np
+
+np.random.seed(42)
+n_patients = 1000
+
+# Generate synthetic patient data
+data = {
+    'patient_id': range(1, n_patients + 1),
+    'age': np.random.normal(50, 15, n_patients).clip(18, 90),
+    'gender': np.random.choice(['M', 'F'], n_patients),
+    'bmi': np.random.normal(25, 5, n_patients).clip(15, 50),
+    'systolic_bp': np.random.normal(120, 20, n_patients).clip(80, 200),
+    'diastolic_bp': np.random.normal(80, 15, n_patients).clip(50, 120),
+    'diabetes': np.random.choice([0, 1], n_patients, p=[0.8, 0.2]),
+    'cholesterol': np.random.normal(200, 40, n_patients).clip(100, 400),
+    'heart_rate': np.random.normal(72, 12, n_patients).clip(50, 120),
+    'smoking': np.random.choice([0, 1], n_patients, p=[0.7, 0.3]),
+    'family_history': np.random.choice([0, 1], n_patients, p=[0.6, 0.4])
+}
+
+df = pd.DataFrame(data)
+df.to_csv('data/raw_patient_data.csv', index=False)
+print('Sample data created!')
+"
+```
+
+### 3. Train the Model
+
+```bash
+# Preprocess data and train VAE
+python src/train.py
+```
+
+### 4. Start the API
+
+```bash
+# Launch FastAPI server
+uvicorn src.api:app --reload --host 0.0.0.0 --port 8000
+```
+
+### 5. Generate Synthetic Data
+
+```bash
+# Test the API
+curl -X POST "http://localhost:8000/generate" \
+     -H "Content-Type: application/json" \
+     -d '{"n_samples": 10, "random_seed": 42}'
+```
+
+## 📊 Data Flow Explanation
+
+### Phase 1: Data Preprocessing (`data_preprocessing.py`)
+
+```python
+Raw Patient Data → Feature Engineering → Normalization → Processed Data
+```
+
+**Operations:**
+- **Missing Value Handling**: Imputation strategies for clinical data
+- **Categorical Encoding**: One-hot encoding for gender, diagnosis codes
+- **Feature Scaling**: StandardScaler for numerical stability
+- **Outlier Detection**: Medical range validation
+- **Feature Engineering**: BMI categories, age groups, risk scores
+
+### Phase 2: Model Architecture (`model.py`)
+
+**VAE Architecture for Healthcare:**
+```
+Input Layer (n_features)
+    ↓
+Encoder Hidden Layers (64→32→16)
+    ↓
+Latent Space (μ, σ) - 8 dimensions
+    ↓
+Decoder Hidden Layers (16→32→64)
+    ↓
+Output Layer (n_features)
+```
+
+**Why This Architecture?**
+- **Small Latent Space (8D)**: Captures essential patient patterns without overfitting
+- **Symmetric Design**: Encoder mirrors decoder for balanced learning
+- **Dropout Regularization**: Prevents overfitting on small medical datasets
+- **Medical Constraints**: Output activations ensure realistic medical ranges
+
+### Phase 3: Training Process (`train.py`)
+
+**Training Loop:**
+```python
+for epoch in range(EPOCHS):
+    # Forward pass
+    patient_data → encoder → (μ, σ) → sample_z → decoder → reconstructed_patient
+
+    # Loss calculation
+    reconstruction_loss = ||original - reconstructed||²
+    kl_loss = KL_divergence(latent_distribution, standard_normal)
+    total_loss = reconstruction_loss + β * kl_loss
+
+    # Optimization
+    optimizer.step()
+
+    # Validation and early stopping
+    if validation_loss_improved:
+        save_best_model()
+```
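To make the loss in the pseudocode above concrete, here is a minimal PyTorch sketch of one training step. It assumes `model.encode(x)` returns `(mu, logvar)` and `model.decode(z)` returns a reconstruction, as in `src/model.py`; the optimizer handling and the use of a plain MSE reconstruction term are illustrative, not the commit's exact `train.py`:

```python
import torch
import torch.nn.functional as F

def vae_step(model, optimizer, batch, beta=1.0):
    """One VAE training step: reconstruction loss + beta * KL divergence."""
    mu, logvar = model.encode(batch)

    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)

    recon = model.decode(z)

    # Reconstruction term: ||original - reconstructed||^2
    recon_loss = F.mse_loss(recon, batch, reduction="sum")

    # Closed-form KL(q(z|x) || N(0, I)) for diagonal Gaussians
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    loss = recon_loss + beta * kl_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```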
+
+**Key Training Features:**
+- **Early Stopping**: Prevents overfitting on medical data
+- **Learning Rate Scheduling**: Adapts learning rate based on progress
+- **Gradient Clipping**: Ensures stable training
+- **Medical Validation**: Checks generated data for medical plausibility
+
+### Phase 4: Synthetic Data Generation
+
+**Generation Process:**
+```python
+# Sample from standard normal distribution
+z ~ N(0, I)  # 8-dimensional latent code
+
+# Decode to patient features
+synthetic_patient = decoder(z)
+
+# Inverse transform to original scale
+real_patient_data = scaler.inverse_transform(synthetic_patient)
+```
+
+**Quality Assurance:**
+- **Statistical Validation**: Mean, variance, correlation preservation
+- **Medical Range Checking**: Ensures realistic vital signs
+- **Diversity Metrics**: Prevents mode collapse
+- **Privacy Metrics**: Ensures no direct patient replication
+
+## 🔧 API Endpoints
+
+### Generate Synthetic Patients
+```http
+POST /generate
+Content-Type: application/json
+
+{
+  "n_samples": 100,
+  "random_seed": 42,
+  "temperature": 1.0
+}
+```
+
+**Response:**
+```json
+{
+  "data": [[patient_features], ...],
+  "metadata": {
+    "n_samples": 100,
+    "latent_dim": 8,
+    "features": ["age", "gender", "bmi", ...]
+  }
+}
+```
+
+### Encode Real Patient
+```http
+POST /encode
+Content-Type: application/json
+
+{
+  "age": 45,
+  "gender": 1,
+  "bmi": 28.5,
+  "systolic_bp": 140,
+  "diabetes": 0,
+  "cholesterol": 220,
+  "heart_rate": 72
+}
+```
+
+### Health Check
+```http
+GET /health
+```
+
+## 📈 Model Performance Metrics
+
+### Training Metrics
+- **Reconstruction Loss**: How well the model recreates original patients
+- **KL Divergence**: How well the latent space follows normal distribution
+- **Validation Loss**: Generalization performance
+
+### Generation Quality Metrics
+- **Statistical Fidelity**: Correlation preservation, distribution matching
+- **Medical Plausibility**: Realistic vital sign ranges, logical relationships
+- **Privacy Protection**: No memorization of training patients
+- **Diversity**: Coverage of different patient types
+
+## ⚙️ Configuration
+
+### Hyperparameters (`train.py`)
+```python
+BATCH_SIZE = 32        # Optimal for small medical datasets
+LEARNING_RATE = 1e-3   # Conservative for stable training
+EPOCHS = 150           # Sufficient for convergence
+LATENT_DIM = 8         # Captures essential patient variations
+BETA = 1.0             # Balance reconstruction vs. regularization
+```
+
+### Model Architecture (`model.py`)
+```python
+INPUT_DIM = 10         # Number of patient features
+HIDDEN_DIMS = [32, 16] # Encoder/decoder layer sizes
+DROPOUT = 0.1          # Regularization strength
+```
+
+## 🔒 Privacy and Compliance
+
+### Privacy Preservation
+- **No Direct Storage**: Original patients not stored in model
+- **Latent Space Learning**: Model learns patterns, not individuals
+- **Differential Privacy**: Optional noise injection for stronger privacy
+- **Audit Trail**: Generation logging for compliance
+
+### HIPAA Compliance Considerations
+- **De-identification**: Remove direct identifiers before training
+- **Access Controls**: Secure API endpoints with authentication
+- **Audit Logging**: Track all data generation requests
+- **Data Minimization**: Only use necessary patient features
+
+## 🧪 Testing and Validation
+
+### Unit Tests
+```bash
+# Run model tests
+python -m pytest tests/test_model.py
+
+# Run API tests
+python -m pytest tests/test_api.py
+```
+
+### Manual Validation
+```python
+# Evaluate model performance
+python src/evaluate.py
+
+# Check generated data quality
+python -c "
+from src.api import generate_synthetic_data
+data = generate_synthetic_data({'n_samples': 100})
+print('Generated data shape:', len(data['data']), 'x', len(data['data'][0]))
+"
+```
+
+## 🚀 Deployment Options
+
+### Local Development
+```bash
+uvicorn src.api:app --reload --port 8000
+```
+
+### Docker Deployment
+```dockerfile
+FROM python:3.9-slim
+COPY . /app
+WORKDIR /app
+RUN pip install -r requirements.txt
+CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]
+```
+
+### Cloud Deployment
+- **AWS**: ECS, Lambda, or SageMaker
+- **GCP**: Cloud Run, AI Platform
+- **Azure**: Container Instances, ML Service
+
+## 📊 Use Cases
+
+### Healthcare AI Training
+- **Augment Small Datasets**: Increase training data for rare conditions
+- **Balance Datasets**: Generate underrepresented patient groups
+- **Privacy-Safe Sharing**: Share synthetic data instead of real patients
+- **Model Testing**: Stress-test AI systems with edge cases
+
+### Research Applications
+- **Clinical Trial Simulation**: Model patient populations
+- **Treatment Planning**: Explore treatment outcomes
+- **Epidemiological Studies**: Study disease patterns
+- **Health Economics**: Model patient costs and outcomes
+
+## 🔮 Future Enhancements
+
+### Model Improvements
+- **Conditional Generation**: Generate patients with specific conditions
+- **Temporal Models**: Patient progression over time
+- **Multi-Modal**: Include medical images, text notes
+- **Federated Learning**: Train across multiple hospitals
+
+### Technical Enhancements
+- **Real-time Generation**: Streaming synthetic data
+- **Model Monitoring**: Drift detection and retraining
+- **A/B Testing**: Compare different generation strategies
+- **Scalability**: Handle larger datasets and more complex models
+
+## 🤝 Contributing
+
+1. Fork the repository
+2. Create a feature branch: `git checkout -b feature/new-feature`
+3. Make changes and test thoroughly
+4. Submit a pull request with detailed description
+
+## 📄 License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+## ⚠️ Disclaimers
+
+- **Research Purpose**: This is for research and development purposes
+- **Medical Advice**: Generated data should not be used for actual medical decisions
+- **Compliance**: Ensure compliance with local healthcare regulations
+- **Validation**: Always validate synthetic data quality for your specific use case
+
+## 📞 Support
+
+For questions or issues:
+- Create an issue on GitHub
+- Email: theaniketgiri@gmail.com
+
+---
+
+**Happy Healthcare AI Development! 🏥🤖**
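The "Statistical Fidelity" checks listed in the readme above are not included in this commit; a minimal pandas sketch of how such a check could look, assuming the synthetic sample shares the column names of data/processed_patient_data.csv:

```python
import numpy as np
import pandas as pd

def fidelity_report(synthetic: pd.DataFrame,
                    real_csv: str = "data/processed_patient_data.csv"):
    """Compare per-column means/stds and correlation structure of real vs. synthetic data."""
    real = pd.read_csv(real_csv)
    summary = pd.DataFrame({
        "real_mean": real.mean(),
        "synth_mean": synthetic.mean(),
        "real_std": real.std(),
        "synth_std": synthetic.std(),
    })
    # Frobenius norm of the gap between the two correlation matrices
    corr_gap = np.linalg.norm(real.corr().values - synthetic.corr().values)
    return summary, corr_gap
```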
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+torch
+pandas
+numpy
+fastapi
+uvicorn
+scikit-learn
+joblib
src/__pycache__/api.cpython-311.pyc
ADDED
Binary file (17.1 kB)

src/__pycache__/model.cpython-311.pyc
ADDED
Binary file (4.22 kB)
src/api.py
ADDED
@@ -0,0 +1,290 @@
+# src/api.py - Enhanced API with better error handling for patient data
+from fastapi import FastAPI, HTTPException, File, UploadFile
+from pydantic import BaseModel, Field
+import torch
+import numpy as np
+import joblib
+from src.model import TabularVAE
+from typing import List, Optional, Dict, Any
+import os
+import shutil
+from fastapi.responses import JSONResponse, HTMLResponse
+import json
+
+app = FastAPI(title="Healthcare VAE API", version="1.0.0")
+
+# Load model and scaler
+try:
+    # Load feature names and determine input dimension
+    if os.path.exists("models/feature_names.pkl"):
+        feature_names = joblib.load("models/feature_names.pkl")
+        INPUT_DIM = len(feature_names)
+        print(f"Loaded {INPUT_DIM} features: {feature_names}")
+    else:
+        # Fallback to default features
+        feature_names = ["age", "gender", "diagnosis", "blood_type", "length_of_stay",
+                         "age_group", "admission_season", "admission_day", "admission_month", "admission_year"]
+        INPUT_DIM = len(feature_names)
+        print(f"Using default {INPUT_DIM} features")
+
+    LATENT_DIM = 8
+
+    model = TabularVAE(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
+    model.load_state_dict(torch.load("models/vae_model.pth", map_location='cpu'))
+    model.eval()
+
+    scaler = joblib.load("models/scaler.pkl")
+
+    # Load encoders if available
+    encoders = None
+    if os.path.exists("models/encoders.pkl"):
+        encoders = joblib.load("models/encoders.pkl")
+
+    print("Model and scaler loaded successfully!")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    print("Please run training first!")
+
+class GenerateRequest(BaseModel):
+    n_samples: int = Field(..., ge=1, le=1000, description="Number of samples to generate")
+    random_seed: Optional[int] = Field(None, description="Random seed for reproducibility")
+    temperature: float = Field(1.0, ge=0.1, le=2.0, description="Sampling temperature")
+
+class PatientData(BaseModel):
+    age: float = Field(..., ge=0, le=120, description="Patient age")
+    gender: str = Field(..., description="Patient gender (Male/Female)")
+    diagnosis: str = Field(..., description="Patient diagnosis")
+    blood_type: str = Field(..., description="Blood type")
+    length_of_stay: Optional[float] = Field(None, description="Length of stay in days")
+    age_group: Optional[int] = Field(None, ge=0, le=4, description="Age group (0-4)")
+    admission_season: Optional[int] = Field(None, ge=0, le=3, description="Admission season (0-3)")
+    admission_day: Optional[int] = Field(None, ge=0, le=6, description="Admission day of week (0-6)")
+    admission_month: Optional[int] = Field(None, ge=0, le=11, description="Admission month (0-11)")
+    admission_year: Optional[int] = Field(None, description="Admission year (normalized)")
+
+class GeneratedResponse(BaseModel):
+    data: List[List[float]]
+    metadata: dict
+
+def convert_numpy_to_python(obj):
+    """Convert numpy types to Python native types for JSON serialization"""
+    if isinstance(obj, np.integer):
+        return int(obj)
+    elif isinstance(obj, np.floating):
+        return float(obj)
+    elif isinstance(obj, np.ndarray):
+        return obj.tolist()
+    elif isinstance(obj, list):
+        return [convert_numpy_to_python(item) for item in obj]
+    elif isinstance(obj, dict):
+        return {key: convert_numpy_to_python(value) for key, value in obj.items()}
+    else:
+        return obj
+
+@app.get("/")
+def read_root():
+    return {"message": "Healthcare VAE API is running!", "features": feature_names}
+
+@app.get("/features")
+def get_features():
+    """Get information about the model features"""
+    return {
+        "feature_names": feature_names,
+        "input_dim": INPUT_DIM,
+        "latent_dim": LATENT_DIM
+    }
+
+@app.post("/generate", response_model=GeneratedResponse)
+def generate_synthetic_data(request: GenerateRequest):
+    try:
+        if request.random_seed is not None:
+            torch.manual_seed(request.random_seed)
+            np.random.seed(request.random_seed)
+
+        # Generate samples
+        z = torch.randn(request.n_samples, LATENT_DIM) * request.temperature
+
+        with torch.no_grad():
+            samples = model.decode(z).numpy()
+
+        # Inverse transform to original scale
+        data = scaler.inverse_transform(samples).tolist()
+
+        metadata = {
+            "n_samples": request.n_samples,
+            "latent_dim": LATENT_DIM,
+            "temperature": request.temperature,
+            "features": feature_names
+        }
+
+        return {"data": data, "metadata": metadata}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+
+@app.post("/encode")
+def encode_patient(patient: PatientData):
+    """Encode patient data to latent space"""
+    try:
+        # Convert patient data to feature vector
+        feature_vector = []
+
+        # Age
+        feature_vector.append(patient.age)
+
+        # Gender (encode if encoders available)
+        if encoders and 'gender' in encoders:
+            gender_encoded = encoders['gender'].transform([patient.gender])[0]
+            feature_vector.append(gender_encoded)
+        else:
+            # Fallback encoding
+            gender_encoded = 0 if patient.gender.lower() == 'male' else 1
+            feature_vector.append(gender_encoded)
+
+        # Diagnosis (encode if encoders available)
+        if encoders and 'diagnosis' in encoders:
+            diagnosis_encoded = encoders['diagnosis'].transform([patient.diagnosis])[0]
+            feature_vector.append(diagnosis_encoded)
+        else:
+            # Fallback encoding (simple hash)
+            diagnosis_encoded = hash(patient.diagnosis) % 10
+            feature_vector.append(diagnosis_encoded)
+
+        # Blood type (encode if encoders available)
+        if encoders and 'blood_type' in encoders:
+            blood_encoded = encoders['blood_type'].transform([patient.blood_type])[0]
+            feature_vector.append(blood_encoded)
+        else:
+            # Fallback encoding (simple hash)
+            blood_encoded = hash(patient.blood_type) % 8
+            feature_vector.append(blood_encoded)
+
+        # Length of stay
+        los = patient.length_of_stay if patient.length_of_stay is not None else 7.0
+        feature_vector.append(los)
+
+        # Age group
+        age_group = patient.age_group if patient.age_group is not None else 2
+        feature_vector.append(age_group)
+
+        # Admission season
+        season = patient.admission_season if patient.admission_season is not None else 0
+        feature_vector.append(season)
+
+        # Admission day
+        day = patient.admission_day if patient.admission_day is not None else 0
+        feature_vector.append(day)
+
+        # Admission month
+        month = patient.admission_month if patient.admission_month is not None else 0
+        feature_vector.append(month)
+
+        # Admission year
+        year = patient.admission_year if patient.admission_year is not None else 4
+        feature_vector.append(year)
+
+        # Ensure we have the right number of features
+        if len(feature_vector) != INPUT_DIM:
+            # Pad or truncate to match input dimension
+            while len(feature_vector) < INPUT_DIM:
+                feature_vector.append(0.0)
+            feature_vector = feature_vector[:INPUT_DIM]
+
+        # Convert to array and scale
+        data = np.array([feature_vector])
+        scaled_data = scaler.transform(data)
+        tensor_data = torch.tensor(scaled_data, dtype=torch.float32)
+
+        with torch.no_grad():
+            mu, logvar = model.encode(tensor_data)
+
+        # Convert numpy types to Python native types for JSON serialization
+        response = {
+            "latent_mean": convert_numpy_to_python(mu.numpy().tolist()),
+            "latent_logvar": convert_numpy_to_python(logvar.numpy().tolist()),
+            "features_used": feature_names,
+            "feature_values": convert_numpy_to_python(feature_vector)
+        }
+
+        return response
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}")
+
+@app.get("/health")
+def health_check():
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "model_loaded": True,
+        "input_dim": INPUT_DIM,
+        "latent_dim": LATENT_DIM
+    }
+
+@app.post("/upload_data")
+async def upload_data(file: UploadFile = File(...)):
+    """Upload a CSV file for continual training."""
+    os.makedirs("data", exist_ok=True)
+    file_location = "data/new_data.csv"
+    with open(file_location, "wb") as buffer:
+        shutil.copyfileobj(file.file, buffer)
+    return {"status": "success", "filename": file.filename}
+
+@app.get("/training_progress")
+def get_training_progress():
+    """Get the latest training progress metrics for the web interface."""
+    progress_file = "data/training_progress.json"
+    if not os.path.exists(progress_file):
+        return JSONResponse(content={"status": "no_progress", "message": "No training progress found."}, status_code=404)
+    with open(progress_file, "r") as f:
+        progress = json.load(f)
+    return JSONResponse(content=progress)
+
+@app.get("/dashboard", response_class=HTMLResponse)
+def dashboard():
+    html = '''
+    <!DOCTYPE html>
+    <html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Training Progress Dashboard</title>
+        <style>
+            body { font-family: Arial, sans-serif; margin: 2em; background: #f9f9f9; }
+            h1 { color: #2c3e50; }
+            #progress { background: #fff; padding: 1em; border-radius: 8px; box-shadow: 0 2px 8px #eee; max-width: 400px; }
+            .label { color: #888; }
+        </style>
+    </head>
+    <body>
+        <h1>Training Progress</h1>
+        <div id="progress">
+            <div><span class="label">Epoch:</span> <span id="epoch">-</span></div>
+            <div><span class="label">Train Loss:</span> <span id="train_loss">-</span></div>
+            <div><span class="label">Val Loss:</span> <span id="val_loss">-</span></div>
+            <div><span class="label">Best Val Loss:</span> <span id="best_val_loss">-</span></div>
+            <div><span class="label">Last Updated:</span> <span id="timestamp">-</span></div>
+        </div>
+        <script>
+        async function fetchProgress() {
+            try {
+                const res = await fetch('/training_progress');
+                if (!res.ok) throw new Error('No progress yet');
+                const data = await res.json();
+                document.getElementById('epoch').textContent = data.epoch;
+                document.getElementById('train_loss').textContent = data.train_loss?.toFixed(4);
+                document.getElementById('val_loss').textContent = data.val_loss?.toFixed(4);
+                document.getElementById('best_val_loss').textContent = data.best_val_loss?.toFixed(4);
+                const date = new Date(data.timestamp * 1000);
+                document.getElementById('timestamp').textContent = date.toLocaleString();
+            } catch (e) {
+                document.getElementById('progress').innerHTML = '<b>No training progress yet.</b>';
+            }
+        }
+        fetchProgress();
+        setInterval(fetchProgress, 3000);
+        </script>
+    </body>
+    </html>
+    '''
+    return HTMLResponse(content=html)
+
+# Run with: uvicorn src.api:app --reload --host 0.0.0.0 --port 8000
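A few example calls against the endpoints defined above, once the server is listening on localhost:8000. The field values are illustrative; the categorical strings sent to /encode should match labels the saved LabelEncoders have seen, e.g. the diagnoses and blood types in data/patient_data.csv:

```bash
# Generate 5 synthetic patients at the default temperature
curl -X POST "http://localhost:8000/generate" \
     -H "Content-Type: application/json" \
     -d '{"n_samples": 5, "random_seed": 42, "temperature": 1.0}'

# Encode one (hypothetical) patient into the 8-dimensional latent space
curl -X POST "http://localhost:8000/encode" \
     -H "Content-Type: application/json" \
     -d '{"age": 45, "gender": "Male", "diagnosis": "Diabetes", "blood_type": "O+", "length_of_stay": 12}'

# Inspect the feature list; the live dashboard is served at /dashboard
curl "http://localhost:8000/features"
```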
src/continual_train.py
ADDED
@@ -0,0 +1,27 @@
+import os
+from src.train import train_vae
+import pandas as pd
+
+def continual_train(progress_callback=None):
+    """
+    Fine-tune the VAE on new data. Optionally log progress via progress_callback.
+    """
+    # Assume new data is already in data/new_data.csv and preprocessed
+    if not os.path.exists("data/new_data.csv"):
+        print("No new data found for continual training.")
+        return
+    # Optionally, preprocess new data if needed (skipped for simplicity)
+    # For now, just retrain on all processed data
+    print("Loading all processed data for fine-tuning...")
+    if os.path.exists("data/processed_patient_data.csv"):
+        feature_df = pd.read_csv("data/processed_patient_data.csv")
+        # Optionally, append new data
+        new_df = pd.read_csv("data/new_data.csv")
+        feature_df = pd.concat([feature_df, new_df], ignore_index=True)
+        feature_df.to_csv("data/processed_patient_data.csv", index=False)
+    else:
+        feature_df = pd.read_csv("data/new_data.csv")
+        feature_df.to_csv("data/processed_patient_data.csv", index=False)
+    print(f"Fine-tuning on {feature_df.shape[0]} samples...")
+    # Call train_vae with progress_callback
+    train_vae(progress_callback=progress_callback)
src/continual_train_loop.py
ADDED
@@ -0,0 +1,29 @@
+import time
+import os
+import json
+
+def continual_train_loop():
+    print("[Continual Training] Loop started. Waiting for new data...")
+    while True:
+        if os.path.exists("data/new_data.csv"):
+            print("[Continual Training] New data found! Fine-tuning model...")
+            from src.continual_train import continual_train
+            continual_train(progress_callback=log_training_progress)
+            os.remove("data/new_data.csv")
+            print("[Continual Training] Model updated and new data file removed.")
+        time.sleep(60)  # Check every 60 seconds
+
+def log_training_progress(epoch, train_loss, val_loss, best_val_loss):
+    progress = {
+        "epoch": epoch,
+        "train_loss": train_loss,
+        "val_loss": val_loss,
+        "best_val_loss": best_val_loss,
+        "timestamp": time.time()
+    }
+    os.makedirs("data", exist_ok=True)
+    with open("data/training_progress.json", "w") as f:
+        json.dump(progress, f)
+
+if __name__ == "__main__":
+    continual_train_loop()
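Putting the pieces together: the loop above polls for data/new_data.csv, which the API's /upload_data endpoint writes, and log_training_progress feeds the /training_progress endpoint and the dashboard. An example of triggering a fine-tuning pass (the filename is hypothetical; per the comment in continual_train.py, the CSV is expected to already contain the processed feature columns):

```bash
# Upload an already-preprocessed CSV; the background loop picks it up
# within about 60 seconds and fine-tunes the model
curl -X POST "http://localhost:8000/upload_data" \
     -F "file=@my_new_patients.csv"

# Poll the metrics that log_training_progress writes to data/training_progress.json
curl "http://localhost:8000/training_progress"
```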
src/data_preprocessing.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# src/data_preprocessing.py - Convert patient data to numerical features
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from datetime import datetime
import joblib
import os

def preprocess_patient_data(csv_file="data/patient_data.csv"):
    """
    Convert patient CSV data to numerical features for VAE training
    """
    print("Loading and preprocessing patient data...")

    # Load data
    df = pd.read_csv(csv_file)
    print(f"Original data shape: {df.shape}")

    # Create numerical features
    features = {}

    # 1. Age (already numerical)
    features['age'] = df['Age'].values

    # 2. Gender (encode: Male=0, Female=1)
    gender_encoder = LabelEncoder()
    features['gender'] = gender_encoder.fit_transform(df['Gender'])

    # 3. Diagnosis (encode categorical)
    diagnosis_encoder = LabelEncoder()
    features['diagnosis'] = diagnosis_encoder.fit_transform(df['Diagnosis'])

    # 4. Blood Type (encode categorical)
    blood_encoder = LabelEncoder()
    features['blood_type'] = blood_encoder.fit_transform(df['BloodType'])

    # 5. Length of stay (calculate from admission/discharge dates)
    df['AdmissionDate'] = pd.to_datetime(df['AdmissionDate'])
    df['DischargeDate'] = pd.to_datetime(df['DischargeDate'])
    features['length_of_stay'] = (df['DischargeDate'] - df['AdmissionDate']).dt.days

    # 6. Age group (create age categories)
    age_bins = [0, 18, 35, 50, 65, 100]
    age_labels = [0, 1, 2, 3, 4]
    features['age_group'] = pd.cut(df['Age'], bins=age_bins, labels=age_labels, include_lowest=True).astype(int)

    # 7. Season of admission (extract from admission date)
    features['admission_season'] = df['AdmissionDate'].dt.quarter - 1  # 0=Q1, 1=Q2, 2=Q3, 3=Q4

    # 8. Day of week admission (0=Monday, 6=Sunday)
    features['admission_day'] = df['AdmissionDate'].dt.dayofweek

    # 9. Month of admission (0-11)
    features['admission_month'] = df['AdmissionDate'].dt.month - 1

    # 10. Year of admission (normalized)
    features['admission_year'] = df['AdmissionDate'].dt.year - 2020  # Normalize to 2020 as base

    # Convert to DataFrame
    feature_df = pd.DataFrame(features)

    # Handle any missing values
    feature_df = feature_df.fillna(feature_df.mean())

    print(f"Processed features shape: {feature_df.shape}")
    print("Feature columns:", list(feature_df.columns))

    # Save encoders for later use
    encoders = {
        'gender': gender_encoder,
        'diagnosis': diagnosis_encoder,
        'blood_type': blood_encoder
    }

    os.makedirs("models", exist_ok=True)
    joblib.dump(encoders, 'models/encoders.pkl')

    # Save processed data
    os.makedirs("data", exist_ok=True)
    feature_df.to_csv('data/processed_patient_data.csv', index=False)

    print("Data preprocessing completed!")
    print(f"Number of features: {feature_df.shape[1]}")

    return feature_df, encoders

def create_sample_data_for_training():
    """
    Create a sample dataset if the original data is not available
    """
    print("Creating sample patient data for training...")

    np.random.seed(42)
    n_samples = 1000

    # Generate realistic patient data
    data = {
        'age': np.random.normal(50, 20, n_samples).clip(1, 100),
        'gender': np.random.choice([0, 1], n_samples),
        'bmi': np.random.normal(25, 5, n_samples).clip(15, 50),
        'blood_pressure': np.random.normal(120, 20, n_samples).clip(80, 200),
        'diabetes': np.random.choice([0, 1], n_samples, p=[0.8, 0.2]),
        'cholesterol': np.random.normal(200, 40, n_samples).clip(100, 300),
        'heart_rate': np.random.normal(75, 15, n_samples).clip(40, 120)
    }

    df = pd.DataFrame(data)
    os.makedirs("data", exist_ok=True)
    df.to_csv('data/patient_data.csv', index=False)

    print(f"Sample data created with {n_samples} patients")
    return df

if __name__ == "__main__":
    try:
        # Try to preprocess the real data
        feature_df, encoders = preprocess_patient_data()
        print("Successfully processed real patient data!")
    except Exception as e:
        print(f"Error processing real data: {e}")
        print("Creating sample data instead...")
        create_sample_data_for_training()
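A minimal usage sketch for the preprocessing module above (assuming it is run from the repository root so that the `src` package and `data/patient_data.csv` both resolve):

# sketch: run preprocessing and inspect the resulting feature matrix
from src.data_preprocessing import preprocess_patient_data

feature_df, encoders = preprocess_patient_data("data/patient_data.csv")
print(feature_df.head())                 # numeric features used for VAE training
print(encoders['diagnosis'].classes_)    # label order chosen by the LabelEncoder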
src/model.py
ADDED
@@ -0,0 +1,57 @@
# Enhanced version with key improvements

# model.py - Add validation and better loss
import torch
import torch.nn as nn
import torch.nn.functional as F

class TabularVAE(nn.Module):
    def __init__(self, input_dim: int, hidden_dims=(64, 32), latent_dim=16):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder
        dims = [input_dim, *hidden_dims]
        self.encoder_layers = nn.ModuleList([
            nn.Linear(dims[i], dims[i+1]) for i in range(len(dims)-1)
        ])
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder
        dims_rev = [latent_dim, *reversed(hidden_dims)]
        self.decoder_layers = nn.ModuleList([
            nn.Linear(dims_rev[i], dims_rev[i+1]) for i in range(len(dims_rev)-1)
        ])
        self.output_layer = nn.Linear(hidden_dims[0], input_dim)

        # Add dropout for better generalization
        self.dropout = nn.Dropout(0.1)

    def encode(self, x):
        h = x
        for layer in self.encoder_layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = z
        for layer in self.decoder_layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        return self.output_layer(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar
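A short shape sanity-check sketch for the model above (the feature count of 10 is an assumption roughly matching the preprocessing script; any positive integer works):

# sketch: verify tensor shapes produced by TabularVAE on random input
import torch
from src.model import TabularVAE

model = TabularVAE(input_dim=10, hidden_dims=(32, 16), latent_dim=8)
x = torch.randn(4, 10)                        # batch of 4 fake, already-scaled rows
recon, mu, logvar = model(x)
print(recon.shape, mu.shape, logvar.shape)    # (4, 10), (4, 8), (4, 8)

# decoding a latent sample directly is how synthetic rows would be generated
z = torch.randn(2, 8)
print(model.decode(z).shape)                  # (2, 10)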
src/train.py
ADDED
@@ -0,0 +1,158 @@
# src/train.py - Enhanced training with validation for patient data
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from src.model import TabularVAE
import joblib
import os

# Hyperparameters
BATCH_SIZE = 32   # Smaller batch size for smaller dataset
LR = 1e-3
EPOCHS = 150      # More epochs for smaller dataset
LATENT_DIM = 8    # Smaller latent dim for smaller dataset
BETA = 1.0        # KL divergence weight

def vae_loss(recon, x, mu, logvar, beta=1.0):
    """Enhanced VAE loss with proper normalization"""
    batch_size = x.size(0)

    # Reconstruction loss (MSE)
    recon_loss = F.mse_loss(recon, x, reduction='sum') / batch_size

    # KL divergence loss
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size

    return recon_loss + beta * kld, recon_loss, kld

def train_vae(progress_callback=None):
    # Check if preprocessed data exists, if not create it
    if not os.path.exists("data/processed_patient_data.csv"):
        print("Preprocessed data not found. Running data preprocessing...")
        from src.data_preprocessing import preprocess_patient_data
        feature_df, encoders = preprocess_patient_data()
    else:
        print("Loading preprocessed data...")
        feature_df = pd.read_csv("data/processed_patient_data.csv")

    print(f"Dataset shape: {feature_df.shape}")
    print(f"Features: {list(feature_df.columns)}")

    # Handle missing values
    feature_df = feature_df.fillna(feature_df.mean())

    # Split data
    train_df, val_df = train_test_split(feature_df, test_size=0.2, random_state=42)

    # Scale data
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_df.values)
    val_data = scaler.transform(val_df.values)

    print(f"Training data shape: {train_data.shape}")
    print(f"Validation data shape: {val_data.shape}")

    # Create data loaders
    train_tensor = torch.tensor(train_data, dtype=torch.float32)
    val_tensor = torch.tensor(val_data, dtype=torch.float32)

    train_loader = DataLoader(TensorDataset(train_tensor), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_tensor), batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model with correct input dimension
    input_dim = train_data.shape[1]
    model = TabularVAE(input_dim=input_dim, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, factor=0.5)

    best_val_loss = float('inf')
    patience_counter = 0
    early_stopping_patience = 30

    print(f"Model initialized with {input_dim} input features and {LATENT_DIM} latent dimensions")
    print(f"Training for {EPOCHS} epochs...")

    # Training loop
    for epoch in range(EPOCHS):
        # Training
        model.train()
        train_loss = 0
        train_recon = 0
        train_kld = 0

        for (batch,) in train_loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_recon += recon_loss.item()
            train_kld += kld_loss.item()

        # Validation
        model.eval()
        val_loss = 0
        val_recon = 0
        val_kld = 0

        with torch.no_grad():
            for (batch,) in val_loader:
                recon, mu, logvar = model(batch)
                loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)

                val_loss += loss.item()
                val_recon += recon_loss.item()
                val_kld += kld_loss.item()

        # Calculate averages
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "models/best_vae_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Print progress
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            print(f"Epoch {epoch+1}/{EPOCHS}")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Train Recon: {train_recon:.4f}, Train KLD: {train_kld:.4f}")
            print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Call progress callback if provided
        if progress_callback:
            progress_callback(epoch+1, train_loss, val_loss, best_val_loss)

    # Save final model and scaler
    torch.save(model.state_dict(), "models/vae_model.pth")
    joblib.dump(scaler, "models/scaler.pkl")

    # Save feature names for API
    feature_names = list(feature_df.columns)
    joblib.dump(feature_names, "models/feature_names.pkl")

    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved with {input_dim} input features")

    return model, scaler, feature_names

if __name__ == "__main__":
    model, scaler, feature_names = train_vae()
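Once training has produced `models/best_vae_model.pth`, `models/scaler.pkl`, and `models/feature_names.pkl`, synthetic rows can be sampled roughly like this (a sketch only; the served API in `src/api.py` is the canonical path, and the hidden/latent sizes below must match the values used in training):

# sketch: sample synthetic patient feature rows from the trained VAE
import torch, joblib
from src.model import TabularVAE

scaler = joblib.load("models/scaler.pkl")
feature_names = joblib.load("models/feature_names.pkl")

input_dim = len(feature_names)
model = TabularVAE(input_dim=input_dim, hidden_dims=(32, 16), latent_dim=8)
model.load_state_dict(torch.load("models/best_vae_model.pth"))
model.eval()

with torch.no_grad():
    z = torch.randn(5, 8)                           # 5 samples from the latent prior
    fake_scaled = model.decode(z).numpy()
fake_rows = scaler.inverse_transform(fake_scaled)   # back to the original feature scale
print(dict(zip(feature_names, fake_rows[0].round(2))))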
src/web_scraper.py
ADDED
@@ -0,0 +1,58 @@
import requests
import pandas as pd
from bs4 import BeautifulSoup
import os
import time

def scrape_table_from_url(url, session=None, table_index=0):
    sess = session or requests.Session()
    resp = sess.get(url)
    soup = BeautifulSoup(resp.text, 'html.parser')
    tables = soup.find_all('table')
    if not tables:
        print(f"No tables found at {url}")
        return None
    df = pd.read_html(str(tables[table_index]))[0]
    return df

def login_and_get_session(login_url, payload):
    sess = requests.Session()
    resp = sess.post(login_url, data=payload)
    if resp.ok:
        print("Login successful.")
        return sess
    else:
        print("Login failed.")
        return None

def scrape_multiple_sources(sources, output_csv):
    all_dfs = []
    for src in sources:
        if src.get('login_url'):
            session = login_and_get_session(src['login_url'], src['login_payload'])
        else:
            session = None
        df = scrape_table_from_url(src['url'], session=session, table_index=src.get('table_index', 0))
        if df is not None:
            all_dfs.append(df)
    if all_dfs:
        combined = pd.concat(all_dfs, ignore_index=True)
        os.makedirs('data', exist_ok=True)
        combined.to_csv(output_csv, index=False)
        print(f"Combined data saved to {output_csv}")
    else:
        print("No data scraped.")

def main_loop():
    sources = [
        {"url": "https://www.somepublichealthsite.org/table1.html"},
        # Add more sources as needed
    ]
    while True:
        print("[Web Scraper] Scraping new data...")
        scrape_multiple_sources(sources, "data/new_data.csv")
        print("[Web Scraper] Waiting 6 hours for next scrape...")
        time.sleep(6 * 60 * 60)  # Wait 6 hours

if __name__ == "__main__":
    main_loop()
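For a one-off run outside the 6-hour loop, the scraper can be driven directly; the URL below is a placeholder, not a real endpoint:

# sketch: single scrape pass without the infinite loop
from src.web_scraper import scrape_multiple_sources

sources = [
    {"url": "https://example.org/some_table.html", "table_index": 0},
    # entries may also carry 'login_url' and 'login_payload' for authenticated sites
]
scrape_multiple_sources(sources, "data/new_data.csv")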
tests/test_api.py
ADDED
@@ -0,0 +1,27 @@
import requests

BASE_URL = "http://localhost:8000"

def test_generate():
    resp = requests.post(f"{BASE_URL}/generate", json={
        "n_samples": 2,
        "temperature": 1.0,
        "random_seed": 42
    })
    print("/generate status:", resp.status_code)
    print("/generate response:", resp.json())

def test_encode():
    resp = requests.post(f"{BASE_URL}/encode", json={
        "age": 45,
        "gender": "Male",
        "diagnosis": "Diabetes",
        "blood_type": "A+",
        "length_of_stay": 7
    })
    print("/encode status:", resp.status_code)
    print("/encode response:", resp.json())

if __name__ == "__main__":
    test_generate()
    test_encode()