theaniketgiri commited on
Commit
902fa1b
·
0 Parent(s):
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# Copy only the dependency manifest first: this keeps the (slow) pip
# layer cached across source-code changes instead of reinstalling every
# build, which `COPY . .` before `pip install` was forcing.
COPY requirements.txt .

# --no-cache-dir keeps pip's download cache out of the image layers.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt beautifulsoup4 lxml

# Application source arrives after dependencies are installed.
COPY . .

# Runtime directories used by the training loop and the API.
RUN mkdir -p data models

EXPOSE 8000

# docker_entrypoint.sh is already in the image via `COPY . .`;
# the previous duplicate COPY was redundant — only chmod is needed.
RUN chmod +x /app/docker_entrypoint.sh

CMD ["/app/docker_entrypoint.sh"]
__pycache__/api.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
__pycache__/model.cpython-311.pyc ADDED
Binary file (4.22 kB). View file
 
data/patient_data.csv ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PatientID,FirstName,LastName,Gender,Age,DateOfBirth,Diagnosis,BloodType,AdmissionDate,DischargeDate
2
+ 1001,Brittney,Davies,Female,32,1992-10-24,Migraine,B+,2024-06-19,2025-05-10
3
+ 1002,Aaron,Johnson,Male,93,1931-08-09,Diabetes,O+,2025-02-09,2025-03-05
4
+ 1003,Nathan,Romero,Male,73,1952-03-22,Asthma,O+,2025-03-15,2025-05-25
5
+ 1004,Corey,Garcia,Male,84,1940-08-21,Pneumonia,AB-,2024-11-09,2024-12-03
6
+ 1005,Brandon,Scott,Male,42,1983-02-20,Diabetes,A-,2024-12-17,2025-02-12
7
+ 1006,Patricia,Bernard,Female,82,1942-08-13,Arthritis,AB+,2025-01-21,2025-05-26
8
+ 1007,Patrick,Sandoval,Male,92,1932-06-26,Heart Disease,AB+,2025-01-09,2025-04-09
9
+ 1008,Heather,Hughes,Female,24,2000-10-18,Heart Disease,O+,2025-01-31,2025-03-10
10
+ 1009,Gina,Kline,Female,35,1989-10-05,Arthritis,AB-,2024-11-22,2025-01-20
11
+ 1010,Cory,Turner,Male,44,1981-01-07,Cancer,B-,2024-10-04,2024-12-07
12
+ 1011,Mary,Ray,Female,44,1981-04-13,Fracture,A-,2024-08-26,2024-11-12
13
+ 1012,Jennifer,Young,Female,80,1944-11-05,Arthritis,B+,2025-06-04,2025-06-25
14
+ 1013,Patricia,Johnson,Female,73,1951-06-19,Fracture,O-,2024-09-03,2024-09-08
15
+ 1014,Matthew,Davis,Male,8,2016-09-17,Arthritis,B+,2025-01-25,2025-05-09
16
+ 1015,Jason,Dodson,Male,62,1962-08-31,COVID-19,A-,2025-03-06,2025-04-27
17
+ 1016,Thomas,Baker,Male,97,1927-07-05,COVID-19,B+,2025-01-09,2025-03-29
18
+ 1017,Leonard,Cochran,Male,82,1943-05-06,Asthma,AB+,2025-02-28,2025-04-04
19
+ 1018,Jessica,Pearson,Female,62,1963-01-01,Heart Disease,B-,2024-07-05,2025-06-11
20
+ 1019,Robin,Johnson,Female,65,1960-05-03,Heart Disease,O-,2025-05-13,2025-05-27
21
+ 1020,Gary,Hill,Male,87,1938-05-31,Fracture,B+,2024-08-27,2025-03-04
22
+ 1021,Kevin,Lee,Male,81,1944-06-06,Pneumonia,O+,2024-10-25,2025-01-09
23
+ 1022,Robert,Powell,Male,68,1956-07-04,Migraine,A-,2025-03-11,2025-03-12
24
+ 1023,Danielle,Wright,Female,42,1983-01-18,Heart Disease,B+,2025-01-20,2025-06-16
25
+ 1024,Hannah,Fields,Female,23,2001-07-25,Fracture,AB-,2025-04-30,2025-06-19
26
+ 1025,Adam,Barton,Male,54,1970-08-11,Migraine,B-,2024-08-15,2024-10-19
27
+ 1026,Paula,Cochran,Female,83,1942-05-14,Diabetes,B+,2024-07-22,2025-02-12
28
+ 1027,Lisa,Christensen,Female,50,1974-10-15,COVID-19,B-,2025-02-24,2025-05-07
29
+ 1028,Joseph,Huff,Male,81,1944-01-04,Fracture,AB+,2025-03-31,2025-06-06
30
+ 1029,Steven,Spears,Male,99,1926-06-15,COVID-19,A-,2024-11-21,2025-01-04
31
+ 1030,Jennifer,Suarez,Female,95,1929-09-11,Asthma,AB+,2024-10-16,2025-06-10
32
+ 1031,Benjamin,Ross,Male,86,1938-07-22,Migraine,AB-,2024-10-01,2025-03-03
33
+ 1032,Michael,Cunningham,Male,70,1954-08-31,Pneumonia,B-,2025-03-09,2025-04-18
34
+ 1033,Tammy,Bullock,Female,39,1986-04-08,Pneumonia,O+,2025-03-16,2025-05-12
35
+ 1034,Nicolas,Harrison,Male,73,1951-11-07,Diabetes,AB+,2024-08-27,2024-09-21
36
+ 1035,Richard,Cortez,Male,44,1980-11-20,Fracture,B+,2025-02-15,2025-04-05
37
+ 1036,Jordan,Hernandez,Female,74,1950-07-01,Pneumonia,B-,2024-09-30,2025-05-06
38
+ 1037,Ralph,Newman,Male,14,2010-08-08,Hypertension,B+,2024-09-25,2025-01-31
39
+ 1038,Renee,Morrison,Female,63,1961-11-14,Pneumonia,A+,2025-03-08,2025-04-13
40
+ 1039,Karen,Clark,Female,91,1933-08-10,Heart Disease,O-,2024-09-30,2025-05-03
41
+ 1040,Lisa,Benson,Female,36,1989-04-04,Asthma,A-,2024-07-19,2025-01-13
42
+ 1041,Randall,May,Male,52,1973-05-06,Pneumonia,AB-,2025-01-13,2025-02-16
43
+ 1042,Michael,Baker,Male,14,2010-08-28,Fracture,B-,2025-02-23,2025-03-12
44
+ 1043,Todd,Harrington,Male,21,2003-10-26,COVID-19,O-,2025-01-18,2025-02-18
45
+ 1044,Terry,Bryan,Male,44,1980-06-22,COVID-19,A+,2025-05-07,2025-05-31
46
+ 1045,Sharon,Smith,Female,83,1942-01-01,Pneumonia,A+,2025-04-01,2025-05-12
47
+ 1046,David,Ray,Male,70,1954-11-26,Diabetes,A-,2024-10-09,2025-04-09
48
+ 1047,Danielle,Dominguez,Female,66,1959-05-25,COVID-19,AB-,2024-09-18,2024-10-08
49
+ 1048,Victoria,Johnson,Female,5,2019-10-05,Hypertension,A+,2025-04-02,2025-06-08
50
+ 1049,Robert,Ferguson,Male,68,1956-06-29,Hypertension,B+,2024-08-19,2025-01-31
51
+ 1050,Brandon,Hall,Male,71,1954-03-05,COVID-19,B-,2024-11-22,2025-05-02
52
+ 1051,Brittany,Bailey,Female,67,1958-01-24,Asthma,A-,2025-02-04,2025-03-07
53
+ 1052,Kyle,Ryan,Male,49,1976-05-07,Migraine,AB-,2024-09-13,2024-11-13
54
+ 1053,Andrew,Smith,Male,91,1933-07-29,Asthma,B-,2024-12-12,2025-02-11
55
+ 1054,Richard,Williams,Male,4,2021-02-14,Pneumonia,O-,2024-09-23,2024-09-30
56
+ 1055,Aaron,Walton,Male,58,1967-04-21,Cancer,B-,2024-09-23,2025-05-31
57
+ 1056,Stephanie,Johnson,Female,44,1980-12-07,Asthma,O+,2024-08-08,2025-01-11
58
+ 1057,Benjamin,Mitchell,Male,56,1968-09-19,Heart Disease,O+,2025-01-14,2025-02-14
59
+ 1058,Jeffrey,Spence,Male,6,2018-09-29,Migraine,B-,2024-08-23,2025-02-01
60
+ 1059,Jamie,Russell,Female,32,1992-09-08,Cancer,A+,2025-04-04,2025-06-26
61
+ 1060,Alexander,Hernandez,Male,37,1987-08-04,Heart Disease,B+,2024-11-18,2025-04-27
62
+ 1061,Nicole,Gibson,Female,54,1970-11-10,Fracture,A+,2025-05-28,2025-06-13
63
+ 1062,Juan,Thompson,Male,11,2014-03-18,Migraine,AB+,2024-12-30,2025-06-10
64
+ 1063,Duane,West,Male,27,1998-02-14,Diabetes,A-,2024-09-22,2024-10-05
65
+ 1064,Natalie,Lee,Female,37,1988-05-04,Hypertension,B+,2025-02-17,2025-05-02
66
+ 1065,James,Liu,Male,37,1987-12-28,Heart Disease,B-,2024-06-27,2024-11-01
67
+ 1066,Pam,Baker,Female,40,1985-05-23,Fracture,B+,2025-02-17,2025-03-27
68
+ 1067,David,Williams,Male,27,1997-09-16,Fracture,AB-,2025-02-12,2025-04-16
69
+ 1068,Mark,Gray,Male,66,1959-01-27,COVID-19,B+,2025-01-26,2025-06-22
70
+ 1069,Jessica,Cannon,Female,98,1926-11-23,Migraine,A-,2024-09-07,2025-05-21
71
+ 1070,Jason,Roberts,Male,28,1997-04-02,Arthritis,AB-,2024-12-28,2025-05-18
72
+ 1071,Robert,Walker,Male,97,1928-03-08,Pneumonia,A-,2025-05-19,2025-06-12
73
+ 1072,Crystal,Williams,Female,22,2002-10-17,Hypertension,O-,2024-09-06,2024-10-25
74
+ 1073,Emily,Thomas,Female,10,2015-05-16,COVID-19,AB-,2024-08-31,2025-03-02
75
+ 1074,Mark,White,Male,46,1979-01-14,Diabetes,AB+,2025-03-30,2025-06-01
76
+ 1075,Jillian,Lucas,Female,47,1977-10-13,Diabetes,A+,2024-10-06,2024-12-23
77
+ 1076,Jeffrey,Nguyen,Male,23,2001-09-06,Arthritis,B+,2025-01-24,2025-04-30
78
+ 1077,Wendy,Nguyen,Female,13,2012-01-20,Cancer,O-,2024-09-18,2025-01-07
79
+ 1078,Lance,Miller,Male,47,1977-09-10,Fracture,AB+,2024-10-12,2025-05-20
80
+ 1079,William,Andrews,Male,69,1955-11-03,COVID-19,O+,2024-06-23,2024-12-19
81
+ 1080,Heather,King,Female,40,1984-10-13,Heart Disease,B-,2025-04-27,2025-05-19
82
+ 1081,Kayla,Fields,Female,25,2000-05-02,Arthritis,AB-,2025-01-04,2025-05-02
83
+ 1082,Sarah,Kline,Female,62,1962-07-04,Hypertension,B+,2024-07-20,2024-09-17
84
+ 1083,Jeremy,Miller,Male,24,2001-06-02,COVID-19,B-,2025-04-08,2025-06-13
85
+ 1084,Melissa,Wallace,Female,70,1955-03-01,Migraine,O+,2024-12-27,2025-05-31
86
+ 1085,John,Weiss,Male,67,1958-03-03,Pneumonia,AB-,2025-05-19,2025-05-26
87
+ 1086,Darren,Herrera,Male,11,2014-04-08,Asthma,A-,2024-07-04,2024-09-18
88
+ 1087,Megan,Kelly,Female,9,2016-03-14,Asthma,A-,2024-06-24,2024-12-10
89
+ 1088,John,Robertson,Male,8,2017-04-20,Heart Disease,O-,2024-08-18,2025-04-24
90
+ 1089,Michelle,Robles,Female,19,2006-06-09,Diabetes,O-,2024-11-29,2025-06-17
91
+ 1090,Lisa,Wright,Female,12,2012-07-14,Cancer,B-,2025-01-21,2025-03-05
92
+ 1091,Melissa,Reynolds,Female,12,2013-03-18,Hypertension,O-,2024-09-29,2024-10-24
93
+ 1092,Kristen,Sanders,Female,48,1976-10-23,Heart Disease,AB+,2025-05-22,2025-06-22
94
+ 1093,Timothy,Short,Male,2,2022-08-08,Heart Disease,B+,2024-12-07,2025-01-12
95
+ 1094,Gene,Greene,Male,91,1933-11-12,Diabetes,AB+,2024-09-11,2024-11-19
96
+ 1095,Jennifer,Brown,Female,2,2023-01-22,Diabetes,A+,2024-07-09,2025-06-17
97
+ 1096,Jeremy,Wright,Male,45,1980-03-29,Cancer,O+,2025-01-25,2025-02-07
98
+ 1097,Patricia,Pierce,Female,58,1967-01-20,COVID-19,A+,2024-12-27,2025-06-16
99
+ 1098,Cindy,Brooks,Female,40,1984-12-06,Asthma,A+,2024-07-19,2025-04-08
100
+ 1099,Seth,Dawson,Male,30,1995-03-05,Diabetes,AB-,2024-12-20,2025-03-30
101
+ 1100,Gene,Kelly,Male,70,1954-08-22,Arthritis,AB-,2024-09-05,2025-02-02
data/processed_patient_data.csv ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ age,gender,diagnosis,blood_type,length_of_stay,age_group,admission_season,admission_day,admission_month,admission_year
2
+ 32,0,8,4,325,1,1,2,5,4
3
+ 93,1,4,6,24,4,0,6,1,5
4
+ 73,1,1,6,71,4,0,5,2,5
5
+ 84,1,9,3,24,4,3,5,10,4
6
+ 42,1,4,1,57,2,3,1,11,4
7
+ 82,0,0,2,125,4,0,1,0,5
8
+ 92,1,6,2,90,4,0,3,0,5
9
+ 24,0,6,6,38,1,0,4,0,5
10
+ 35,0,0,3,59,1,3,4,10,4
11
+ 44,1,3,5,64,2,3,4,9,4
12
+ 44,0,5,1,78,2,2,0,7,4
13
+ 80,0,0,4,21,4,1,2,5,5
14
+ 73,0,5,7,5,4,2,1,8,4
15
+ 8,1,0,4,104,0,0,5,0,5
16
+ 62,1,2,1,52,3,0,3,2,5
17
+ 97,1,2,4,79,4,0,3,0,5
18
+ 82,1,1,2,35,4,0,4,1,5
19
+ 62,0,6,5,341,3,2,4,6,4
20
+ 65,0,6,7,14,3,1,1,4,5
21
+ 87,1,5,4,189,4,2,1,7,4
22
+ 81,1,9,6,76,4,3,4,9,4
23
+ 68,1,8,1,1,4,0,1,2,5
24
+ 42,0,6,4,147,2,0,0,0,5
25
+ 23,0,5,3,50,1,1,2,3,5
26
+ 54,1,8,5,65,3,2,3,7,4
27
+ 83,0,4,4,205,4,2,0,6,4
28
+ 50,0,2,5,72,2,0,0,1,5
29
+ 81,1,5,2,67,4,0,0,2,5
30
+ 99,1,2,1,44,4,3,3,10,4
31
+ 95,0,1,2,237,4,3,2,9,4
32
+ 86,1,8,3,153,4,3,1,9,4
33
+ 70,1,9,5,40,4,0,6,2,5
34
+ 39,0,9,6,57,2,0,6,2,5
35
+ 73,1,4,2,25,4,2,1,7,4
36
+ 44,1,5,4,49,2,0,5,1,5
37
+ 74,0,9,5,218,4,2,0,8,4
38
+ 14,1,7,4,128,0,2,2,8,4
39
+ 63,0,9,0,36,3,0,5,2,5
40
+ 91,0,6,7,215,4,2,0,8,4
41
+ 36,0,1,1,178,2,2,4,6,4
42
+ 52,1,9,3,34,3,0,0,0,5
43
+ 14,1,5,5,17,0,0,6,1,5
44
+ 21,1,2,7,31,1,0,5,0,5
45
+ 44,1,2,0,24,2,1,2,4,5
46
+ 83,0,9,0,41,4,1,1,3,5
47
+ 70,1,4,1,182,4,3,2,9,4
48
+ 66,0,2,3,20,4,2,2,8,4
49
+ 5,0,7,0,67,0,1,2,3,5
50
+ 68,1,7,4,165,4,2,0,7,4
51
+ 71,1,2,5,161,4,3,4,10,4
52
+ 67,0,1,1,31,4,0,1,1,5
53
+ 49,1,8,3,61,2,2,4,8,4
54
+ 91,1,1,5,61,4,3,3,11,4
55
+ 4,1,9,7,7,0,2,0,8,4
56
+ 58,1,3,5,250,3,2,0,8,4
57
+ 44,0,1,6,156,2,2,3,7,4
58
+ 56,1,6,6,31,3,0,1,0,5
59
+ 6,1,8,5,162,0,2,4,7,4
60
+ 32,0,3,0,83,1,1,4,3,5
61
+ 37,1,6,4,160,2,3,0,10,4
62
+ 54,0,5,0,16,3,1,2,4,5
63
+ 11,1,8,2,162,0,3,0,11,4
64
+ 27,1,4,1,13,1,2,6,8,4
65
+ 37,0,7,4,74,2,0,0,1,5
66
+ 37,1,6,5,127,2,1,3,5,4
67
+ 40,0,5,4,38,2,0,0,1,5
68
+ 27,1,5,3,63,1,0,2,1,5
69
+ 66,1,2,4,147,4,0,6,0,5
70
+ 98,0,8,1,256,4,2,5,8,4
71
+ 28,1,0,3,141,1,3,5,11,4
72
+ 97,1,9,1,24,4,1,0,4,5
73
+ 22,0,7,7,49,1,2,4,8,4
74
+ 10,0,2,3,183,0,2,5,7,4
75
+ 46,1,4,2,63,2,0,6,2,5
76
+ 47,0,4,0,78,2,3,6,9,4
77
+ 23,1,0,4,96,1,0,4,0,5
78
+ 13,0,3,7,111,0,2,2,8,4
79
+ 47,1,5,2,220,2,3,5,9,4
80
+ 69,1,2,6,179,4,1,6,5,4
81
+ 40,0,6,5,22,2,1,6,3,5
82
+ 25,0,0,3,118,1,0,5,0,5
83
+ 62,0,7,4,59,3,2,5,6,4
84
+ 24,1,2,5,66,1,1,1,3,5
85
+ 70,0,8,6,155,4,3,4,11,4
86
+ 67,1,9,3,7,4,1,0,4,5
87
+ 11,1,1,1,76,0,2,3,6,4
88
+ 9,0,1,1,169,0,1,0,5,4
89
+ 8,1,6,7,249,0,2,6,7,4
90
+ 19,0,4,7,200,1,3,4,10,4
91
+ 12,0,3,5,43,0,0,1,0,5
92
+ 12,0,7,7,25,0,2,6,8,4
93
+ 48,0,6,2,31,2,1,3,4,5
94
+ 2,1,6,4,36,0,3,5,11,4
95
+ 91,1,4,2,69,4,2,2,8,4
96
+ 2,0,4,0,343,0,2,1,6,4
97
+ 45,1,3,6,13,2,0,5,0,5
98
+ 58,0,2,0,171,3,3,4,11,4
99
+ 40,0,1,0,263,2,2,4,6,4
100
+ 30,1,4,3,100,1,3,4,11,4
101
+ 70,1,0,3,150,4,2,3,8,4
docker_entrypoint.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/sh
# Container entrypoint: launch background workers, then serve the API.
# Fail fast if any startup command errors.
set -e

# Start continual training loop in background
python src/continual_train_loop.py &

# Start web scraper in background
python src/web_scraper.py &

# Start FastAPI server in the foreground. `exec` replaces this shell so
# uvicorn receives SIGTERM from `docker stop` directly and shuts down
# cleanly. NOTE(review): the previous `--reload` flag is a development
# feature (file-watcher + worker restarts) and was removed for the
# container entrypoint.
exec uvicorn src.api:app --host 0.0.0.0 --port 8000
models/best_vae_model.pth ADDED
Binary file (13.7 kB). View file
 
models/encoders.pkl ADDED
Binary file (1.1 kB). View file
 
models/feature_names.pkl ADDED
Binary file (155 Bytes). View file
 
models/scaler.pkl ADDED
Binary file (855 Bytes). View file
 
models/vae_model.pth ADDED
Binary file (13.6 kB). View file
 
readme.md ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Healthcare Synthetic Data VAE
3
+ emoji: "🏥"
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: "3.41.2"
8
+ app_file: src/api.py
9
+ pinned: false
10
+ ---
11
+ # Healthcare Synthetic Data Generation using VAE
12
+
13
+ A complete pipeline for generating synthetic healthcare data using Variational Autoencoders (VAE) with FastAPI serving capabilities.
14
+
15
+ ## 🏥 Project Overview
16
+
17
+ This project implements a **Variational Autoencoder (VAE)** to generate synthetic patient data for healthcare AI applications. The system can create realistic patient records while preserving privacy and statistical properties of the original data.
18
+
19
+ ### Key Features
20
+ - 🔬 **Medical Data Generation**: Creates synthetic patient records with realistic correlations
21
+ - 🔒 **Privacy-Preserving**: No direct storage of original patient data
22
+ - 🚀 **Production Ready**: FastAPI deployment with RESTful endpoints
23
+ - 📊 **Quality Validation**: Built-in data quality metrics and evaluation
24
+ - 🎛️ **Configurable**: Easy hyperparameter tuning and model customization
25
+
26
+ ## 🏗️ Project Structure
27
+
28
+ ```
29
+ healthcare-vae/
30
+ ├── README.md # This file
31
+ ├── requirements.txt # Python dependencies
32
+ ├── data/
33
+ │ ├── raw_patient_data.csv # Original patient data (not included)
34
+ │ └── processed_patient_data.csv # Preprocessed features
35
+ ├── src/
36
+ │ ├── model.py # VAE architecture
37
+ │ ├── data_preprocessing.py # Data cleaning and feature engineering
38
+ │ ├── train.py # Model training script
39
+ │ ├── evaluate.py # Model evaluation and metrics
40
+ │ └── api.py # FastAPI serving application
41
+ ├── models/
42
+ │ ├── vae_model.pth # Trained VAE weights
43
+ │ ├── best_vae_model.pth # Best model checkpoint
44
+ │ ├── scaler.pkl # Data preprocessing scaler
45
+ │ └── feature_names.pkl # Feature column names
46
+ ├── notebooks/
47
+ │ ├── data_exploration.ipynb # Data analysis and visualization
48
+ │ └── model_analysis.ipynb # Model performance analysis
49
+ └── tests/
50
+ ├── test_model.py # Model unit tests
51
+ └── test_api.py # API endpoint tests
52
+ ```
53
+
54
+ ## 🧠 How VAE Works for Healthcare Data
55
+
56
+ ### Mathematical Foundation
57
+
58
+ **Variational Autoencoder (VAE)** learns a compressed representation of patient data:
59
+
60
+ ```
61
+ Patient Data → Encoder → Latent Space (μ, σ) → Decoder → Synthetic Patient
62
+ ```
63
+
64
+ **Key Components:**
65
+ 1. **Encoder**: Maps patient features to latent space parameters
66
+ 2. **Latent Space**: Continuous representation of patient "types"
67
+ 3. **Decoder**: Generates new patients from latent codes
68
+ 4. **Loss Function**: Reconstruction + KL Divergence
69
+
70
+ ### Training Process
71
+
72
+ ```mermaid
73
+ graph TD
74
+ A[Raw Patient Data] --> B[Data Preprocessing]
75
+ B --> C[Feature Engineering]
76
+ C --> D[Train/Validation Split]
77
+ D --> E[VAE Training]
78
+ E --> F[Model Validation]
79
+ F --> G{Good Performance?}
80
+ G -->|No| H[Adjust Hyperparameters]
81
+ H --> E
82
+ G -->|Yes| I[Save Best Model]
83
+ I --> J[Deploy API]
84
+ ```
85
+
86
+ ## 🚀 Quick Start
87
+
88
+ ### 1. Installation
89
+
90
+ ```bash
91
+ # Clone repository
92
+ git clone https://github.com/theaniketgiri/healthcare-vae.git
93
+ cd healthcare-vae
94
+
95
+ # Install dependencies
96
+ pip install -r requirements.txt
97
+
98
+ # Create necessary directories
99
+ mkdir -p data models notebooks tests
100
+ ```
101
+
102
+ ### 2. Data Preparation
103
+
104
+ ```python
105
+ # Create sample patient data (for demonstration)
106
+ python -c "
107
+ import pandas as pd
108
+ import numpy as np
109
+
110
+ np.random.seed(42)
111
+ n_patients = 1000
112
+
113
+ # Generate synthetic patient data
114
+ data = {
115
+ 'patient_id': range(1, n_patients + 1),
116
+ 'age': np.random.normal(50, 15, n_patients).clip(18, 90),
117
+ 'gender': np.random.choice(['M', 'F'], n_patients),
118
+ 'bmi': np.random.normal(25, 5, n_patients).clip(15, 50),
119
+ 'systolic_bp': np.random.normal(120, 20, n_patients).clip(80, 200),
120
+ 'diastolic_bp': np.random.normal(80, 15, n_patients).clip(50, 120),
121
+ 'diabetes': np.random.choice([0, 1], n_patients, p=[0.8, 0.2]),
122
+ 'cholesterol': np.random.normal(200, 40, n_patients).clip(100, 400),
123
+ 'heart_rate': np.random.normal(72, 12, n_patients).clip(50, 120),
124
+ 'smoking': np.random.choice([0, 1], n_patients, p=[0.7, 0.3]),
125
+ 'family_history': np.random.choice([0, 1], n_patients, p=[0.6, 0.4])
126
+ }
127
+
128
+ df = pd.DataFrame(data)
129
+ df.to_csv('data/raw_patient_data.csv', index=False)
130
+ print('Sample data created!')
131
+ "
132
+ ```
133
+
134
+ ### 3. Train the Model
135
+
136
+ ```bash
137
+ # Preprocess data and train VAE
138
+ python src/train.py
139
+ ```
140
+
141
+ ### 4. Start the API
142
+
143
+ ```bash
144
+ # Launch FastAPI server
145
+ uvicorn src.api:app --reload --host 0.0.0.0 --port 8000
146
+ ```
147
+
148
+ ### 5. Generate Synthetic Data
149
+
150
+ ```bash
151
+ # Test the API
152
+ curl -X POST "http://localhost:8000/generate" \
153
+ -H "Content-Type: application/json" \
154
+ -d '{"n_samples": 10, "random_seed": 42}'
155
+ ```
156
+
157
+ ## 📊 Data Flow Explanation
158
+
159
+ ### Phase 1: Data Preprocessing (`data_preprocessing.py`)
160
+
161
+ ```python
162
+ Raw Patient Data → Feature Engineering → Normalization → Processed Data
163
+ ```
164
+
165
+ **Operations:**
166
+ - **Missing Value Handling**: Imputation strategies for clinical data
167
+ - **Categorical Encoding**: One-hot encoding for gender, diagnosis codes
168
+ - **Feature Scaling**: StandardScaler for numerical stability
169
+ - **Outlier Detection**: Medical range validation
170
+ - **Feature Engineering**: BMI categories, age groups, risk scores
171
+
172
+ ### Phase 2: Model Architecture (`model.py`)
173
+
174
+ **VAE Architecture for Healthcare:**
175
+ ```
176
+ Input Layer (n_features)
177
+
178
+ Encoder Hidden Layers (32→16)
179
+
180
+ Latent Space (μ, σ) - 8 dimensions
181
+
182
+ Decoder Hidden Layers (16→32)
183
+
184
+ Output Layer (n_features)
185
+ ```
186
+
187
+ **Why This Architecture?**
188
+ - **Small Latent Space (8D)**: Captures essential patient patterns without overfitting
189
+ - **Symmetric Design**: Encoder mirrors decoder for balanced learning
190
+ - **Dropout Regularization**: Prevents overfitting on small medical datasets
191
+ - **Medical Constraints**: Output activations ensure realistic medical ranges
192
+
193
+ ### Phase 3: Training Process (`train.py`)
194
+
195
+ **Training Loop:**
196
+ ```python
197
+ for epoch in range(EPOCHS):
198
+ # Forward pass
199
+ patient_data → encoder → (μ, σ) → sample_z → decoder → reconstructed_patient
200
+
201
+ # Loss calculation
202
+ reconstruction_loss = ||original - reconstructed||²
203
+ kl_loss = KL_divergence(latent_distribution, standard_normal)
204
+ total_loss = reconstruction_loss + β * kl_loss
205
+
206
+ # Optimization
207
+ optimizer.step()
208
+
209
+ # Validation and early stopping
210
+ if validation_loss_improved:
211
+ save_best_model()
212
+ ```
213
+
214
+ **Key Training Features:**
215
+ - **Early Stopping**: Prevents overfitting on medical data
216
+ - **Learning Rate Scheduling**: Adapts learning rate based on progress
217
+ - **Gradient Clipping**: Ensures stable training
218
+ - **Medical Validation**: Checks generated data for medical plausibility
219
+
220
+ ### Phase 4: Synthetic Data Generation
221
+
222
+ **Generation Process:**
223
+ ```python
224
+ # Sample from standard normal distribution
225
+ z ~ N(0, I) # 8-dimensional latent code
226
+
227
+ # Decode to patient features
228
+ synthetic_patient = decoder(z)
229
+
230
+ # Inverse transform to original scale
231
+ real_patient_data = scaler.inverse_transform(synthetic_patient)
232
+ ```
233
+
234
+ **Quality Assurance:**
235
+ - **Statistical Validation**: Mean, variance, correlation preservation
236
+ - **Medical Range Checking**: Ensures realistic vital signs
237
+ - **Diversity Metrics**: Prevents mode collapse
238
+ - **Privacy Metrics**: Ensures no direct patient replication
239
+
240
+ ## 🔧 API Endpoints
241
+
242
+ ### Generate Synthetic Patients
243
+ ```http
244
+ POST /generate
245
+ Content-Type: application/json
246
+
247
+ {
248
+ "n_samples": 100,
249
+ "random_seed": 42,
250
+ "temperature": 1.0
251
+ }
252
+ ```
253
+
254
+ **Response:**
255
+ ```json
256
+ {
257
+ "data": [[patient_features], ...],
258
+ "metadata": {
259
+ "n_samples": 100,
260
+ "latent_dim": 8,
261
+ "features": ["age", "gender", "bmi", ...]
262
+ }
263
+ }
264
+ ```
265
+
266
+ ### Encode Real Patient
267
+ ```http
268
+ POST /encode
269
+ Content-Type: application/json
270
+
271
+ {
272
+ "age": 45,
273
+ "gender": 1,
274
+ "bmi": 28.5,
275
+ "systolic_bp": 140,
276
+ "diabetes": 0,
277
+ "cholesterol": 220,
278
+ "heart_rate": 72
279
+ }
280
+ ```
281
+
282
+ ### Health Check
283
+ ```http
284
+ GET /health
285
+ ```
286
+
287
+ ## 📈 Model Performance Metrics
288
+
289
+ ### Training Metrics
290
+ - **Reconstruction Loss**: How well the model recreates original patients
291
+ - **KL Divergence**: How well the latent space follows normal distribution
292
+ - **Validation Loss**: Generalization performance
293
+
294
+ ### Generation Quality Metrics
295
+ - **Statistical Fidelity**: Correlation preservation, distribution matching
296
+ - **Medical Plausibility**: Realistic vital sign ranges, logical relationships
297
+ - **Privacy Protection**: No memorization of training patients
298
+ - **Diversity**: Coverage of different patient types
299
+
300
+ ## ⚙️ Configuration
301
+
302
+ ### Hyperparameters (`train.py`)
303
+ ```python
304
+ BATCH_SIZE = 32 # Optimal for small medical datasets
305
+ LEARNING_RATE = 1e-3 # Conservative for stable training
306
+ EPOCHS = 150 # Sufficient for convergence
307
+ LATENT_DIM = 8 # Captures essential patient variations
308
+ BETA = 1.0 # Balance reconstruction vs. regularization
309
+ ```
310
+
311
+ ### Model Architecture (`model.py`)
312
+ ```python
313
+ INPUT_DIM = 10 # Number of patient features
314
+ HIDDEN_DIMS = [32, 16] # Encoder/decoder layer sizes
315
+ DROPOUT = 0.1 # Regularization strength
316
+ ```
317
+
318
+ ## 🔒 Privacy and Compliance
319
+
320
+ ### Privacy Preservation
321
+ - **No Direct Storage**: Original patients not stored in model
322
+ - **Latent Space Learning**: Model learns patterns, not individuals
323
+ - **Differential Privacy**: Optional noise injection for stronger privacy
324
+ - **Audit Trail**: Generation logging for compliance
325
+
326
+ ### HIPAA Compliance Considerations
327
+ - **De-identification**: Remove direct identifiers before training
328
+ - **Access Controls**: Secure API endpoints with authentication
329
+ - **Audit Logging**: Track all data generation requests
330
+ - **Data Minimization**: Only use necessary patient features
331
+
332
+ ## 🧪 Testing and Validation
333
+
334
+ ### Unit Tests
335
+ ```bash
336
+ # Run model tests
337
+ python -m pytest tests/test_model.py
338
+
339
+ # Run API tests
340
+ python -m pytest tests/test_api.py
341
+ ```
342
+
343
+ ### Manual Validation
344
+ ```python
345
+ # Evaluate model performance
346
+ python src/evaluate.py
347
+
348
+ # Check generated data quality
349
+ python -c "
350
+ from src.api import generate_synthetic_data
351
+ data = generate_synthetic_data({'n_samples': 100})
352
+ print('Generated data shape:', len(data['data']), 'x', len(data['data'][0]))
353
+ "
354
+ ```
355
+
356
+ ## 🚀 Deployment Options
357
+
358
+ ### Local Development
359
+ ```bash
360
+ uvicorn src.api:app --reload --port 8000
361
+ ```
362
+
363
+ ### Docker Deployment
364
+ ```dockerfile
365
+ FROM python:3.11-slim
366
+ COPY . /app
367
+ WORKDIR /app
368
+ RUN pip install -r requirements.txt
369
+ CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]
370
+ ```
371
+
372
+ ### Cloud Deployment
373
+ - **AWS**: ECS, Lambda, or SageMaker
374
+ - **GCP**: Cloud Run, AI Platform
375
+ - **Azure**: Container Instances, ML Service
376
+
377
+ ## 📊 Use Cases
378
+
379
+ ### Healthcare AI Training
380
+ - **Augment Small Datasets**: Increase training data for rare conditions
381
+ - **Balance Datasets**: Generate underrepresented patient groups
382
+ - **Privacy-Safe Sharing**: Share synthetic data instead of real patients
383
+ - **Model Testing**: Stress-test AI systems with edge cases
384
+
385
+ ### Research Applications
386
+ - **Clinical Trial Simulation**: Model patient populations
387
+ - **Treatment Planning**: Explore treatment outcomes
388
+ - **Epidemiological Studies**: Study disease patterns
389
+ - **Health Economics**: Model patient costs and outcomes
390
+
391
+ ## 🔮 Future Enhancements
392
+
393
+ ### Model Improvements
394
+ - **Conditional Generation**: Generate patients with specific conditions
395
+ - **Temporal Models**: Patient progression over time
396
+ - **Multi-Modal**: Include medical images, text notes
397
+ - **Federated Learning**: Train across multiple hospitals
398
+
399
+ ### Technical Enhancements
400
+ - **Real-time Generation**: Streaming synthetic data
401
+ - **Model Monitoring**: Drift detection and retraining
402
+ - **A/B Testing**: Compare different generation strategies
403
+ - **Scalability**: Handle larger datasets and more complex models
404
+
405
+ ## 🤝 Contributing
406
+
407
+ 1. Fork the repository
408
+ 2. Create a feature branch: `git checkout -b feature/new-feature`
409
+ 3. Make changes and test thoroughly
410
+ 4. Submit a pull request with detailed description
411
+
412
+ ## 📄 License
413
+
414
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
415
+
416
+ ## ⚠️ Disclaimers
417
+
418
+ - **Research Purpose**: This is for research and development purposes
419
+ - **Medical Advice**: Generated data should not be used for actual medical decisions
420
+ - **Compliance**: Ensure compliance with local healthcare regulations
421
+ - **Validation**: Always validate synthetic data quality for your specific use case
422
+
423
+ ## 📞 Support
424
+
425
+ For questions or issues:
426
+ - Create an issue on GitHub
427
+ - Email: theaniketgiri@gmail.com
428
+
429
+ ---
430
+
431
+ **Happy Healthcare AI Development! 🏥🤖**
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ pandas
3
+ numpy
4
+ fastapi
5
+ uvicorn
6
+ scikit-learn
7
+ joblib
src/__pycache__/api.cpython-311.pyc ADDED
Binary file (17.1 kB). View file
 
src/__pycache__/model.cpython-311.pyc ADDED
Binary file (4.22 kB). View file
 
src/api.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# src/api.py - Enhanced API with better error handling for patient data
from fastapi import FastAPI, HTTPException, File, UploadFile
from pydantic import BaseModel, Field
import torch
import numpy as np
import joblib
from src.model import TabularVAE
from typing import List, Optional, Dict, Any
import os
import shutil
from fastapi.responses import JSONResponse, HTMLResponse
import json

app = FastAPI(title="Healthcare VAE API", version="1.0.0")

# Pre-initialise the shared artifacts to safe defaults so that, if loading
# fails part-way through the try-block below, the endpoints raise a clear
# error instead of a NameError on an undefined module global.
feature_names = ["age", "gender", "diagnosis", "blood_type", "length_of_stay",
                 "age_group", "admission_season", "admission_day",
                 "admission_month", "admission_year"]
INPUT_DIM = len(feature_names)
LATENT_DIM = 8
model = None
scaler = None
encoders = None

# Load model and scaler produced by src/train.py
try:
    # Feature names saved at training time determine the model's input width.
    if os.path.exists("models/feature_names.pkl"):
        feature_names = joblib.load("models/feature_names.pkl")
        INPUT_DIM = len(feature_names)
        print(f"Loaded {INPUT_DIM} features: {feature_names}")
    else:
        print(f"Using default {INPUT_DIM} features")

    # Architecture must match the training configuration in src/train.py.
    model = TabularVAE(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    model.load_state_dict(torch.load("models/vae_model.pth", map_location='cpu'))
    model.eval()

    scaler = joblib.load("models/scaler.pkl")

    # Categorical label encoders are optional (older training runs lack them).
    if os.path.exists("models/encoders.pkl"):
        encoders = joblib.load("models/encoders.pkl")

    print("Model and scaler loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please run training first!")
47
+
48
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""
    # How many synthetic patient rows to decode (capped at 1000 per call).
    n_samples: int = Field(..., ge=1, le=1000, description="Number of samples to generate")
    # Seeds both torch and numpy when given, for reproducible generation.
    random_seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    # Scales the latent noise; <1.0 yields more typical samples, >1.0 more varied.
    temperature: float = Field(1.0, ge=0.1, le=2.0, description="Sampling temperature")
52
+
53
class PatientData(BaseModel):
    """Request body for POST /encode: one patient record.

    Field order mirrors the training feature order; optional fields are
    filled with neutral defaults by the /encode handler when omitted.
    """
    age: float = Field(..., ge=0, le=120, description="Patient age")
    gender: str = Field(..., description="Patient gender (Male/Female)")
    diagnosis: str = Field(..., description="Patient diagnosis")
    blood_type: str = Field(..., description="Blood type")
    length_of_stay: Optional[float] = Field(None, description="Length of stay in days")
    age_group: Optional[int] = Field(None, ge=0, le=4, description="Age group (0-4)")
    admission_season: Optional[int] = Field(None, ge=0, le=3, description="Admission season (0-3)")
    admission_day: Optional[int] = Field(None, ge=0, le=6, description="Admission day of week (0-6)")
    admission_month: Optional[int] = Field(None, ge=0, le=11, description="Admission month (0-11)")
    admission_year: Optional[int] = Field(None, description="Admission year (normalized)")
64
+
65
class GeneratedResponse(BaseModel):
    """Response body for POST /generate."""
    # Generated rows in the original (inverse-scaled) feature space,
    # one inner list per synthetic patient, ordered like metadata["features"].
    data: List[List[float]]
    # Echoed generation settings: n_samples, latent_dim, temperature, features.
    metadata: dict
68
+
69
def convert_numpy_to_python(obj):
    """Recursively convert numpy types to native Python types for JSON output.

    FastAPI's JSON encoder cannot serialise numpy scalars/arrays, so response
    payloads built from model or scaler output are passed through this helper.

    Handles numpy integers, floats, bools (np.bool_ is NOT a subclass of
    np.integer, so it needs its own branch) and ndarrays, and recurses into
    lists, tuples and dicts.  Anything else is returned unchanged.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, list):
        return [convert_numpy_to_python(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_numpy_to_python(item) for item in obj)
    elif isinstance(obj, dict):
        return {key: convert_numpy_to_python(value) for key, value in obj.items()}
    else:
        return obj
83
+
84
@app.get("/")
def read_root():
    """Root endpoint: liveness message plus the model's feature list."""
    payload = {"message": "Healthcare VAE API is running!"}
    payload["features"] = feature_names
    return payload
87
+
88
@app.get("/features")
def get_features():
    """Get information about the model features."""
    info = {"feature_names": feature_names}
    info["input_dim"] = INPUT_DIM
    info["latent_dim"] = LATENT_DIM
    return info
96
+
97
@app.post("/generate", response_model=GeneratedResponse)
def generate_synthetic_data(request: GenerateRequest):
    """Sample synthetic patient rows from the VAE prior.

    Draws latent vectors from a temperature-scaled standard normal, decodes
    them, and maps the decoder output back to the original feature scale via
    the fitted scaler.
    """
    try:
        seed = request.random_seed
        if seed is not None:
            # Seed both RNGs so generation is reproducible.
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Latent samples, scaled by the requested temperature.
        latent = request.temperature * torch.randn(request.n_samples, LATENT_DIM)

        with torch.no_grad():
            decoded = model.decode(latent).numpy()

        # Undo the StandardScaler that was fitted during training.
        rows = scaler.inverse_transform(decoded).tolist()

        meta = {
            "n_samples": request.n_samples,
            "latent_dim": LATENT_DIM,
            "temperature": request.temperature,
            "features": feature_names,
        }
        return {"data": rows, "metadata": meta}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
124
+
125
def _encode_category(name, value, hash_buckets):
    """Encode one categorical value via the trained LabelEncoder when present.

    Falls back to a hash bucket when models/encoders.pkl was not saved.
    NOTE(review): the hash fallback is not stable across interpreter runs
    (PYTHONHASHSEED) and will not match the training-time encoding — it only
    keeps the endpoint usable when encoders are missing.
    """
    if encoders and name in encoders:
        return encoders[name].transform([value])[0]
    return hash(value) % hash_buckets


@app.post("/encode")
def encode_patient(patient: PatientData):
    """Encode patient data to latent space.

    Builds the numeric feature vector in the same column order used during
    training, scales it with the fitted scaler, and returns the encoder's
    posterior mean/log-variance.  Optional fields get neutral defaults.
    """
    try:
        features = [patient.age]

        # Gender: trained encoder when available, else Male=0 / anything-else=1.
        if encoders and 'gender' in encoders:
            features.append(encoders['gender'].transform([patient.gender])[0])
        else:
            features.append(0 if patient.gender.lower() == 'male' else 1)

        # Diagnosis and blood type share the same encode-or-hash pattern.
        features.append(_encode_category('diagnosis', patient.diagnosis, 10))
        features.append(_encode_category('blood_type', patient.blood_type, 8))

        # Optional fields: substitute neutral defaults when omitted.
        features.append(patient.length_of_stay if patient.length_of_stay is not None else 7.0)
        features.append(patient.age_group if patient.age_group is not None else 2)
        features.append(patient.admission_season if patient.admission_season is not None else 0)
        features.append(patient.admission_day if patient.admission_day is not None else 0)
        features.append(patient.admission_month if patient.admission_month is not None else 0)
        features.append(patient.admission_year if patient.admission_year is not None else 4)

        # Defensively pad/truncate if the trained model expects another width.
        if len(features) != INPUT_DIM:
            features = (features + [0.0] * INPUT_DIM)[:INPUT_DIM]

        # Scale exactly as at training time, then run the encoder.
        scaled = scaler.transform(np.array([features]))
        tensor_data = torch.tensor(scaled, dtype=torch.float32)
        with torch.no_grad():
            mu, logvar = model.encode(tensor_data)

        # Convert numpy types to Python native types for JSON serialization.
        return {
            "latent_mean": convert_numpy_to_python(mu.numpy().tolist()),
            "latent_logvar": convert_numpy_to_python(logvar.numpy().tolist()),
            "features_used": feature_names,
            "feature_values": convert_numpy_to_python(features),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}")
212
+
213
@app.get("/health")
def health_check():
    """Health check endpoint.

    Fix: reports whether the model artifacts actually loaded at import time.
    The original hard-coded model_loaded=True, so monitoring saw a healthy
    service even when model loading had failed.  Uses globals().get so the
    check also works if the load block failed before defining `model`.
    """
    loaded = globals().get("model") is not None
    return {
        "status": "healthy" if loaded else "degraded",
        "model_loaded": loaded,
        "input_dim": INPUT_DIM,
        "latent_dim": LATENT_DIM
    }
222
+
223
@app.post("/upload_data")
async def upload_data(file: UploadFile = File(...)):
    """Upload a CSV file for continual training.

    The upload is always written to data/new_data.csv (overwriting any
    pending file); the continual-training loop polls for that path and
    consumes it.  The original filename is only echoed back, never used
    as a filesystem path.
    """
    os.makedirs("data", exist_ok=True)
    file_location = "data/new_data.csv"
    # Stream the upload to disk in chunks rather than reading it into memory.
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"status": "success", "filename": file.filename}
231
+
232
@app.get("/training_progress")
def get_training_progress():
    """Get the latest training progress metrics for the web interface."""
    progress_path = "data/training_progress.json"
    # 404 until the continual-training loop has written its first snapshot.
    if not os.path.exists(progress_path):
        body = {"status": "no_progress", "message": "No training progress found."}
        return JSONResponse(content=body, status_code=404)
    with open(progress_path, "r") as fh:
        metrics = json.load(fh)
    return JSONResponse(content=metrics)
241
+
242
@app.get("/dashboard", response_class=HTMLResponse)
def dashboard():
    """Serve a self-contained HTML dashboard that polls /training_progress.

    The page refreshes its metrics every 3 seconds via fetch(); when the
    progress endpoint 404s it shows a "no progress" placeholder instead.
    """
    # Inline HTML (no templating engine needed for a single static page).
    html = '''
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Training Progress Dashboard</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 2em; background: #f9f9f9; }
            h1 { color: #2c3e50; }
            #progress { background: #fff; padding: 1em; border-radius: 8px; box-shadow: 0 2px 8px #eee; max-width: 400px; }
            .label { color: #888; }
        </style>
    </head>
    <body>
        <h1>Training Progress</h1>
        <div id="progress">
            <div><span class="label">Epoch:</span> <span id="epoch">-</span></div>
            <div><span class="label">Train Loss:</span> <span id="train_loss">-</span></div>
            <div><span class="label">Val Loss:</span> <span id="val_loss">-</span></div>
            <div><span class="label">Best Val Loss:</span> <span id="best_val_loss">-</span></div>
            <div><span class="label">Last Updated:</span> <span id="timestamp">-</span></div>
        </div>
        <script>
        async function fetchProgress() {
            try {
                const res = await fetch('/training_progress');
                if (!res.ok) throw new Error('No progress yet');
                const data = await res.json();
                document.getElementById('epoch').textContent = data.epoch;
                document.getElementById('train_loss').textContent = data.train_loss?.toFixed(4);
                document.getElementById('val_loss').textContent = data.val_loss?.toFixed(4);
                document.getElementById('best_val_loss').textContent = data.best_val_loss?.toFixed(4);
                const date = new Date(data.timestamp * 1000);
                document.getElementById('timestamp').textContent = date.toLocaleString();
            } catch (e) {
                document.getElementById('progress').innerHTML = '<b>No training progress yet.</b>';
            }
        }
        fetchProgress();
        setInterval(fetchProgress, 3000);
        </script>
    </body>
    </html>
    '''
    return HTMLResponse(content=html)
289
+
290
+ # Run with: uvicorn src.api:app --reload --host 0.0.0.0 --port 8000
src/continual_train.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from src.train import train_vae
3
+ import pandas as pd
4
+
5
def continual_train(progress_callback=None):
    """
    Fine-tune the VAE on new data. Optionally log progress via progress_callback.

    Appends data/new_data.csv onto data/processed_patient_data.csv (creating
    it on first run) and reruns the full training pipeline on the result.

    NOTE(review): new_data.csv is assumed to already be in the processed
    numeric feature schema — no preprocessing or column validation happens
    here, and calling this twice with the same file appends duplicate rows.
    Confirm upstream callers (the polling loop deletes the file after use)
    guarantee both.
    """
    # Assume new data is already in data/new_data.csv and preprocessed
    if not os.path.exists("data/new_data.csv"):
        print("No new data found for continual training.")
        return
    # Optionally, preprocess new data if needed (skipped for simplicity)
    # For now, just retrain on all processed data
    print("Loading all processed data for fine-tuning...")
    if os.path.exists("data/processed_patient_data.csv"):
        feature_df = pd.read_csv("data/processed_patient_data.csv")
        # Append the new rows onto the accumulated processed dataset.
        new_df = pd.read_csv("data/new_data.csv")
        feature_df = pd.concat([feature_df, new_df], ignore_index=True)
        feature_df.to_csv("data/processed_patient_data.csv", index=False)
    else:
        # First run: the new data becomes the processed dataset.
        feature_df = pd.read_csv("data/new_data.csv")
        feature_df.to_csv("data/processed_patient_data.csv", index=False)
    print(f"Fine-tuning on {feature_df.shape[0]} samples...")
    # Call train_vae with progress_callback
    train_vae(progress_callback=progress_callback)
src/continual_train_loop.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import os
3
+ import json
4
+
5
def continual_train_loop():
    """Poll for data/new_data.csv every 60 seconds and fine-tune when found.

    Runs forever; intended to be launched as a background process.  The
    new-data file is removed after a fine-tune so each upload is consumed
    exactly once.  Note: if continual_train raises, the file is NOT removed
    and the loop will retry on the next poll.
    """
    print("[Continual Training] Loop started. Waiting for new data...")
    while True:
        if os.path.exists("data/new_data.csv"):
            print("[Continual Training] New data found! Fine-tuning model...")
            # Imported at point of use rather than module top level.
            from src.continual_train import continual_train
            continual_train(progress_callback=log_training_progress)
            os.remove("data/new_data.csv")
            print("[Continual Training] Model updated and new data file removed.")
        time.sleep(60)  # Check every 60 seconds
16
def log_training_progress(epoch, train_loss, val_loss, best_val_loss):
    """Persist the latest training metrics to data/training_progress.json.

    Used as the progress_callback during continual training; the
    /training_progress API endpoint serves this file to the dashboard.
    """
    snapshot = {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "best_val_loss": best_val_loss,
        "timestamp": time.time(),
    }
    os.makedirs("data", exist_ok=True)
    with open("data/training_progress.json", "w") as fh:
        json.dump(snapshot, fh)
27
+
28
+ if __name__ == "__main__":
29
+ continual_train_loop()
src/data_preprocessing.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/data_preprocessing.py - Convert patient data to numerical features
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
5
+ from datetime import datetime
6
+ import joblib
7
+ import os
8
+
9
def preprocess_patient_data(csv_file="data/patient_data.csv"):
    """
    Convert patient CSV data to numerical features for VAE training

    Expects the data/patient_data.csv schema (Age, Gender, Diagnosis,
    BloodType, AdmissionDate, DischargeDate columns).  Saves the fitted
    label encoders to models/encoders.pkl and the numeric feature table to
    data/processed_patient_data.csv.

    Returns:
        (feature_df, encoders): the numeric feature DataFrame and a dict of
        the fitted LabelEncoders keyed by 'gender'/'diagnosis'/'blood_type'.
    """
    print("Loading and preprocessing patient data...")

    # Load data
    df = pd.read_csv(csv_file)
    print(f"Original data shape: {df.shape}")

    # Create numerical features (column order here defines the model's input order)
    features = {}

    # 1. Age (already numerical)
    features['age'] = df['Age'].values

    # 2. Gender (LabelEncoder assigns codes alphabetically, e.g. Female=0, Male=1)
    gender_encoder = LabelEncoder()
    features['gender'] = gender_encoder.fit_transform(df['Gender'])

    # 3. Diagnosis (encode categorical)
    diagnosis_encoder = LabelEncoder()
    features['diagnosis'] = diagnosis_encoder.fit_transform(df['Diagnosis'])

    # 4. Blood Type (encode categorical)
    blood_encoder = LabelEncoder()
    features['blood_type'] = blood_encoder.fit_transform(df['BloodType'])

    # 5. Length of stay in days (discharge minus admission)
    df['AdmissionDate'] = pd.to_datetime(df['AdmissionDate'])
    df['DischargeDate'] = pd.to_datetime(df['DischargeDate'])
    features['length_of_stay'] = (df['DischargeDate'] - df['AdmissionDate']).dt.days

    # 6. Age group buckets: 0-18, 18-35, 35-50, 50-65, 65-100
    # NOTE(review): ages above 100 fall outside the bins and become NaN,
    # which makes .astype(int) raise — confirm upstream data is clean.
    age_bins = [0, 18, 35, 50, 65, 100]
    age_labels = [0, 1, 2, 3, 4]
    features['age_group'] = pd.cut(df['Age'], bins=age_bins, labels=age_labels, include_lowest=True).astype(int)

    # 7. Season of admission (quarter as a proxy: 0=Q1, 1=Q2, 2=Q3, 3=Q4)
    features['admission_season'] = df['AdmissionDate'].dt.quarter - 1

    # 8. Day of week admission (0=Monday, 6=Sunday)
    features['admission_day'] = df['AdmissionDate'].dt.dayofweek

    # 9. Month of admission, zero-based (0-11)
    features['admission_month'] = df['AdmissionDate'].dt.month - 1

    # 10. Year of admission offset from a 2020 baseline
    features['admission_year'] = df['AdmissionDate'].dt.year - 2020

    # Convert to DataFrame
    feature_df = pd.DataFrame(features)

    # Handle any missing values with per-column means
    feature_df = feature_df.fillna(feature_df.mean())

    print(f"Processed features shape: {feature_df.shape}")
    print("Feature columns:", list(feature_df.columns))

    # Save encoders so the API can encode incoming categorical values the same way
    encoders = {
        'gender': gender_encoder,
        'diagnosis': diagnosis_encoder,
        'blood_type': blood_encoder
    }

    os.makedirs("models", exist_ok=True)
    joblib.dump(encoders, 'models/encoders.pkl')

    # Save processed data for train.py to consume
    os.makedirs("data", exist_ok=True)
    feature_df.to_csv('data/processed_patient_data.csv', index=False)

    print("Data preprocessing completed!")
    print(f"Number of features: {feature_df.shape[1]}")

    return feature_df, encoders
+
87
def create_sample_data_for_training():
    """
    Create a sample dataset if the original data is not available

    Fix over the original: the generated sample now uses the same schema as
    the real data/patient_data.csv (Age, Gender, Diagnosis, BloodType,
    AdmissionDate, DischargeDate), so preprocess_patient_data() can actually
    consume the fallback file.  The original wrote an unrelated numeric
    schema (bmi, blood_pressure, ...) that the preprocessor could not parse.

    Returns the generated DataFrame (also saved to data/patient_data.csv).
    """
    print("Creating sample patient data for training...")

    np.random.seed(42)
    n_samples = 1000

    # Categorical vocabularies mirroring the real dataset.
    diagnoses = ['Diabetes', 'Asthma', 'Pneumonia', 'Arthritis',
                 'Heart Disease', 'Cancer', 'Fracture', 'Migraine']
    blood_types = ['A+', 'A-', 'B+', 'B-', 'AB+', 'AB-', 'O+', 'O-']

    ages = np.random.normal(50, 20, n_samples).clip(1, 100).astype(int)
    # Admissions spread over one year; stays of 1-59 days.
    admission = pd.to_datetime('2024-01-01') + pd.to_timedelta(
        np.random.randint(0, 365, n_samples), unit='D')
    discharge = admission + pd.to_timedelta(
        np.random.randint(1, 60, n_samples), unit='D')

    df = pd.DataFrame({
        'PatientID': np.arange(1001, 1001 + n_samples),
        'Gender': np.random.choice(['Male', 'Female'], n_samples),
        'Age': ages,
        'Diagnosis': np.random.choice(diagnoses, n_samples),
        'BloodType': np.random.choice(blood_types, n_samples),
        'AdmissionDate': admission.strftime('%Y-%m-%d'),
        'DischargeDate': discharge.strftime('%Y-%m-%d'),
    })

    os.makedirs("data", exist_ok=True)
    df.to_csv('data/patient_data.csv', index=False)

    print(f"Sample data created with {n_samples} patients")
    return df
+
114
+ if __name__ == "__main__":
115
+ try:
116
+ # Try to preprocess the real data
117
+ feature_df, encoders = preprocess_patient_data()
118
+ print("Successfully processed real patient data!")
119
+ except Exception as e:
120
+ print(f"Error processing real data: {e}")
121
+ print("Creating sample data instead...")
122
+ create_sample_data_for_training()
src/model.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Enhanced version with key improvements
2
+
3
+ # model.py - Add validation and better loss
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
class TabularVAE(nn.Module):
    """Variational autoencoder for tabular data.

    Symmetric MLP encoder/decoder with ReLU activations and light dropout.
    encode() returns the posterior mean and log-variance; decode() maps a
    latent vector back to feature space with a linear (unbounded) output,
    suitable for standardized features.
    """

    def __init__(self, input_dim: int, hidden_dims=(64, 32), latent_dim=16):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder stack: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]
        enc_dims = [input_dim, *hidden_dims]
        self.encoder_layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(enc_dims, enc_dims[1:])
        )
        # Posterior parameter heads.
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder mirrors the encoder: latent_dim -> reversed hidden dims
        dec_dims = [latent_dim, *reversed(hidden_dims)]
        self.decoder_layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dec_dims, dec_dims[1:])
        )
        self.output_layer = nn.Linear(hidden_dims[0], input_dim)

        # Dropout after every hidden activation, for regularization.
        self.dropout = nn.Dropout(0.1)

    def encode(self, x):
        """Map inputs to the posterior parameters (mu, logvar)."""
        hidden = x
        for linear in self.encoder_layers:
            hidden = self.dropout(F.relu(linear(hidden)))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        """Map a latent vector back to the input feature space."""
        hidden = z
        for linear in self.decoder_layers:
            hidden = self.dropout(F.relu(linear(hidden)))
        return self.output_layer(hidden)

    def forward(self, x):
        """Full VAE pass: returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar
src/train.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/train.py - Enhanced training with validation for patient data
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import DataLoader, TensorDataset
5
+ import pandas as pd
6
+ import numpy as np
7
+ from sklearn.preprocessing import StandardScaler
8
+ from sklearn.model_selection import train_test_split
9
+ from src.model import TabularVAE
10
+ import joblib
11
+ import os
12
+
13
# Hyperparameters (tuned for the small ~100-row patient dataset)
BATCH_SIZE = 32  # Smaller batch size for smaller dataset
LR = 1e-3  # Adam learning rate
EPOCHS = 150  # More epochs for smaller dataset (early stopping may end sooner)
LATENT_DIM = 8  # Smaller latent dim for smaller dataset
BETA = 1.0  # KL divergence weight (beta-VAE style weighting)
19
+
20
def vae_loss(recon, x, mu, logvar, beta=1.0):
    """VAE loss: MSE reconstruction plus beta-weighted KL divergence.

    Both terms are summed over features and averaged over the batch.
    Returns (total, recon_loss, kld).
    """
    n = x.size(0)

    # Reconstruction term (sum-of-squares, normalized per sample).
    recon_term = F.mse_loss(recon, x, reduction='sum') / n

    # KL(q(z|x) || N(0, I)), written as 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1).
    kl_term = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1) / n

    return recon_term + beta * kl_term, recon_term, kl_term
31
+
32
def train_vae(progress_callback=None):
    """Train the TabularVAE on the processed patient dataset.

    Runs preprocessing if data/processed_patient_data.csv is missing, fits a
    StandardScaler on the training split only, trains with early stopping and
    LR scheduling, and saves model/scaler/feature-name artifacts under models/.

    Args:
        progress_callback: optional fn(epoch, train_loss, val_loss,
            best_val_loss) invoked every epoch (used by the continual
            training loop to publish progress to the dashboard).

    Returns:
        (model, scaler, feature_names)
    """
    # Check if preprocessed data exists, if not create it
    if not os.path.exists("data/processed_patient_data.csv"):
        print("Preprocessed data not found. Running data preprocessing...")
        from src.data_preprocessing import preprocess_patient_data
        feature_df, encoders = preprocess_patient_data()
    else:
        print("Loading preprocessed data...")
        feature_df = pd.read_csv("data/processed_patient_data.csv")

    print(f"Dataset shape: {feature_df.shape}")
    print(f"Features: {list(feature_df.columns)}")

    # Handle missing values
    feature_df = feature_df.fillna(feature_df.mean())

    # Split data
    train_df, val_df = train_test_split(feature_df, test_size=0.2, random_state=42)

    # Scale data (fit on train only, so validation statistics don't leak)
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_df.values)
    val_data = scaler.transform(val_df.values)

    print(f"Training data shape: {train_data.shape}")
    print(f"Validation data shape: {val_data.shape}")

    # Create data loaders
    train_tensor = torch.tensor(train_data, dtype=torch.float32)
    val_tensor = torch.tensor(val_data, dtype=torch.float32)
    train_loader = DataLoader(TensorDataset(train_tensor), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_tensor), batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model with correct input dimension
    input_dim = train_data.shape[1]
    model = TabularVAE(input_dim=input_dim, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, factor=0.5)

    # Fix: ensure the output directory exists before any torch.save below.
    # The original crashed on a fresh checkout where models/ did not exist.
    os.makedirs("models", exist_ok=True)

    best_val_loss = float('inf')
    patience_counter = 0
    early_stopping_patience = 30

    print(f"Model initialized with {input_dim} input features and {LATENT_DIM} latent dimensions")
    print(f"Training for {EPOCHS} epochs...")

    # Training loop
    for epoch in range(EPOCHS):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_recon = 0.0
        train_kld = 0.0

        for (batch,) in train_loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_recon += recon_loss.item()
            train_kld += kld_loss.item()

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for (batch,) in val_loader:
                recon, mu, logvar = model(batch)
                loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
                val_loss += loss.item()

        # Per-batch averages.  Fix: the original printed unaveraged recon/KLD
        # sums next to averaged totals, which made the logs misleading.
        train_loss /= len(train_loader)
        train_recon /= len(train_loader)
        train_kld /= len(train_loader)
        val_loss /= len(val_loader)

        # Learning rate scheduling on the validation loss
        scheduler.step(val_loss)

        # Save best model seen so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "models/best_vae_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Print progress
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            print(f"Epoch {epoch+1}/{EPOCHS}")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Train Recon: {train_recon:.4f}, Train KLD: {train_kld:.4f}")
            print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")
        # Call progress callback if provided
        if progress_callback:
            progress_callback(epoch+1, train_loss, val_loss, best_val_loss)

    # Save final model and scaler
    torch.save(model.state_dict(), "models/vae_model.pth")
    joblib.dump(scaler, "models/scaler.pkl")

    # Save feature names so the API can rebuild the input vector ordering
    feature_names = list(feature_df.columns)
    joblib.dump(feature_names, "models/feature_names.pkl")

    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved with {input_dim} input features")

    return model, scaler, feature_names
156
+
157
+ if __name__ == "__main__":
158
+ model, scaler, feature_names = train_vae()
src/web_scraper.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import pandas as pd
3
+ from bs4 import BeautifulSoup
4
+ import os
5
+ import time
6
+
7
def scrape_table_from_url(url, session=None, table_index=0):
    """Fetch *url* and parse its table_index-th HTML <table> into a DataFrame.

    Uses the provided authenticated session when given, otherwise a fresh
    one.  Returns None when the page has no tables or the requested index
    is out of range.

    Fixes over the original: the request now has a timeout (the original
    could hang the scraper loop forever), an out-of-range table_index no
    longer raises IndexError, and pd.read_html gets a file-like object
    (passing a raw string is deprecated in pandas).
    """
    from io import StringIO  # local import: only needed by this function

    sess = session or requests.Session()
    resp = sess.get(url, timeout=30)
    soup = BeautifulSoup(resp.text, 'html.parser')
    tables = soup.find_all('table')
    if not tables:
        print(f"No tables found at {url}")
        return None
    if table_index >= len(tables):
        print(f"Table index {table_index} out of range at {url}")
        return None
    df = pd.read_html(StringIO(str(tables[table_index])))[0]
    return df
+
18
def login_and_get_session(login_url, payload):
    """POST credentials to *login_url*; return the session on success, else None."""
    sess = requests.Session()
    response = sess.post(login_url, data=payload)
    if not response.ok:
        print("Login failed.")
        return None
    print("Login successful.")
    return sess
+
28
def scrape_multiple_sources(sources, output_csv):
    """Scrape every configured source and write the combined rows to *output_csv*.

    Each source dict carries a 'url', an optional 'table_index', and optional
    'login_url' + 'login_payload' for sources that need authentication.
    Sources that yield no table are simply skipped.
    """
    frames = []
    for source in sources:
        session = None
        if source.get('login_url'):
            session = login_and_get_session(source['login_url'], source['login_payload'])
        table = scrape_table_from_url(source['url'], session=session,
                                      table_index=source.get('table_index', 0))
        if table is not None:
            frames.append(table)

    if not frames:
        print("No data scraped.")
        return

    combined = pd.concat(frames, ignore_index=True)
    os.makedirs('data', exist_ok=True)
    combined.to_csv(output_csv, index=False)
    print(f"Combined data saved to {output_csv}")
+
46
def main_loop():
    """Scrape the configured public sources every 6 hours, forever.

    Writes combined results to data/new_data.csv, the file the continual
    training loop polls for.  NOTE(review): the single configured URL looks
    like a placeholder — confirm real sources before deploying.
    """
    sources = [
        {"url": "https://www.somepublichealthsite.org/table1.html"},
        # Add more sources as needed
    ]
    while True:
        print("[Web Scraper] Scraping new data...")
        scrape_multiple_sources(sources, "data/new_data.csv")
        print("[Web Scraper] Waiting 6 hours for next scrape...")
        time.sleep(6 * 60 * 60)  # Wait 6 hours
+
57
+ if __name__ == "__main__":
58
+ main_loop()
tests/test_api.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ BASE_URL = "http://localhost:8000"
4
+
5
def test_generate():
    """Exercise POST /generate against a running server and verify the result.

    Fix: the original only printed the response, so this test could never
    fail; it now asserts the status code and the number of generated rows.
    """
    resp = requests.post(f"{BASE_URL}/generate", json={
        "n_samples": 2,
        "temperature": 1.0,
        "random_seed": 42
    })
    print("/generate status:", resp.status_code)
    print("/generate response:", resp.json())
    assert resp.status_code == 200
    body = resp.json()
    assert len(body["data"]) == 2
    assert body["metadata"]["n_samples"] == 2
+
14
def test_encode():
    """Exercise POST /encode against a running server and verify the result.

    Fix: the original only printed the response, so this test could never
    fail; it now asserts the status code and that latent parameters are
    present in the payload.
    """
    resp = requests.post(f"{BASE_URL}/encode", json={
        "age": 45,
        "gender": "Male",
        "diagnosis": "Diabetes",
        "blood_type": "A+",
        "length_of_stay": 7
    })
    print("/encode status:", resp.status_code)
    print("/encode response:", resp.json())
    assert resp.status_code == 200
    body = resp.json()
    assert "latent_mean" in body
    assert "latent_logvar" in body
+
25
+ if __name__ == "__main__":
26
+ test_generate()
27
+ test_encode()