diff --git a/.ipynb_checkpoints/requirements-checkpoint.txt b/.ipynb_checkpoints/requirements-checkpoint.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b21550486dfb28e440ab2a45993d1a52b9f1e901
Binary files /dev/null and b/.ipynb_checkpoints/requirements-checkpoint.txt differ
diff --git a/data/BTK_-_100A___Aortic_Dissection.docx b/data/BTK_-_100A___Aortic_Dissection.docx
new file mode 100644
index 0000000000000000000000000000000000000000..4ffad254537ca90b01ac884cd0102896373f00f1
Binary files /dev/null and b/data/BTK_-_100A___Aortic_Dissection.docx differ
diff --git a/data/BTK_-_10A___Pancreatic_Neoplasm.docx b/data/BTK_-_10A___Pancreatic_Neoplasm.docx
new file mode 100644
index 0000000000000000000000000000000000000000..373445a82fdef6a2d60df644040c193cd7dac084
Binary files /dev/null and b/data/BTK_-_10A___Pancreatic_Neoplasm.docx differ
diff --git a/data/BTK_-_11A___Pancreatic_Adenocarcinoma.docx b/data/BTK_-_11A___Pancreatic_Adenocarcinoma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..6c7d2f6a84ac29838fe4d57d6fcf41451335762a
Binary files /dev/null and b/data/BTK_-_11A___Pancreatic_Adenocarcinoma.docx differ
diff --git a/data/BTK_-_12A___Pancreatic_Mass_Endocrine.docx b/data/BTK_-_12A___Pancreatic_Mass_Endocrine.docx
new file mode 100644
index 0000000000000000000000000000000000000000..f4f0322a6a00805d08879401931274ae82aaca73
Binary files /dev/null and b/data/BTK_-_12A___Pancreatic_Mass_Endocrine.docx differ
diff --git a/data/BTK_-_13A___Pancreatitis.docx b/data/BTK_-_13A___Pancreatitis.docx
new file mode 100644
index 0000000000000000000000000000000000000000..ace1642013b48d303cf772315c178848de5de64e
Binary files /dev/null and b/data/BTK_-_13A___Pancreatitis.docx differ
diff --git a/data/BTK_-_14A___Esophageal_Cancer.docx b/data/BTK_-_14A___Esophageal_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..514d9c208d0a5a178a2482abbde02a8a7b1cf742
Binary files /dev/null and b/data/BTK_-_14A___Esophageal_Cancer.docx differ
diff --git a/data/BTK_-_15A___Esophageal_Dysmotility.docx b/data/BTK_-_15A___Esophageal_Dysmotility.docx
new file mode 100644
index 0000000000000000000000000000000000000000..b6f906269d7cef31c07b8683a7e8ec3c3b96cc1b
Binary files /dev/null and b/data/BTK_-_15A___Esophageal_Dysmotility.docx differ
diff --git a/data/BTK_-_16A___Esophageal_Perforation.docx b/data/BTK_-_16A___Esophageal_Perforation.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2cfadf4f2aa63731f9be46579b147165dcceb111
Binary files /dev/null and b/data/BTK_-_16A___Esophageal_Perforation.docx differ
diff --git a/data/BTK_-_17A____Gastric_Cancer.docx b/data/BTK_-_17A____Gastric_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..7dfe379fd10f32015c0a1e154b0c6261fe68c275
Binary files /dev/null and b/data/BTK_-_17A____Gastric_Cancer.docx differ
diff --git a/data/BTK_-_18A___Gastrointestinal_Stromal_Tumor_GIST.docx b/data/BTK_-_18A___Gastrointestinal_Stromal_Tumor_GIST.docx
new file mode 100644
index 0000000000000000000000000000000000000000..9733301550809f52c5a010a9f20c379e366ec9c9
Binary files /dev/null and b/data/BTK_-_18A___Gastrointestinal_Stromal_Tumor_GIST.docx differ
diff --git a/data/BTK_-_19A___Bariatric.docx b/data/BTK_-_19A___Bariatric.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2d1924bf1d886576d399327437d648b8d42f6726
Binary files /dev/null and b/data/BTK_-_19A___Bariatric.docx differ
diff --git a/data/BTK_-_1A___Rectus_Sheath_Hematoma_.docx b/data/BTK_-_1A___Rectus_Sheath_Hematoma_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..934b08ca0105be807332107b8ec338620b8741a4
Binary files /dev/null and b/data/BTK_-_1A___Rectus_Sheath_Hematoma_.docx differ
diff --git a/data/BTK_-_20A___Upper_GI_Bleed.docx b/data/BTK_-_20A___Upper_GI_Bleed.docx
new file mode 100644
index 0000000000000000000000000000000000000000..fa7f0e13673aa80bda8ceaa44686beea9af9d4fe
Binary files /dev/null and b/data/BTK_-_20A___Upper_GI_Bleed.docx differ
diff --git a/data/BTK_-_21A___Acute_Mesenteric_Ischemia_.docx b/data/BTK_-_21A___Acute_Mesenteric_Ischemia_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..0a6ed50c5406d300a4686b656aa49325a7d6a857
Binary files /dev/null and b/data/BTK_-_21A___Acute_Mesenteric_Ischemia_.docx differ
diff --git a/data/BTK_-_22A___EC_Fistula.docx b/data/BTK_-_22A___EC_Fistula.docx
new file mode 100644
index 0000000000000000000000000000000000000000..bc59ae252530f829e6414f7c5283975ce158b79f
Binary files /dev/null and b/data/BTK_-_22A___EC_Fistula.docx differ
diff --git a/data/BTK_-_23A___Small_Bowel_Obstruction.docx b/data/BTK_-_23A___Small_Bowel_Obstruction.docx
new file mode 100644
index 0000000000000000000000000000000000000000..19cf45c701cbb9eb1be6d002126b68768f1c2b1c
Binary files /dev/null and b/data/BTK_-_23A___Small_Bowel_Obstruction.docx differ
diff --git a/data/BTK_-_24A___Colon_Cancer.docx b/data/BTK_-_24A___Colon_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..6fd8d60a4f5a0e09af64f4e4e11d7e95ebf8fde9
Binary files /dev/null and b/data/BTK_-_24A___Colon_Cancer.docx differ
diff --git a/data/BTK_-_25A___Inflammatory_Bowel_Disease.docx b/data/BTK_-_25A___Inflammatory_Bowel_Disease.docx
new file mode 100644
index 0000000000000000000000000000000000000000..23e596735da41860b21e47c806f1642bbcba2b31
Binary files /dev/null and b/data/BTK_-_25A___Inflammatory_Bowel_Disease.docx differ
diff --git a/data/BTK_-_26A___Ischemic_Colitis.docx b/data/BTK_-_26A___Ischemic_Colitis.docx
new file mode 100644
index 0000000000000000000000000000000000000000..b89dcb9a51a6013455d25291ccf89acb439b1fff
Binary files /dev/null and b/data/BTK_-_26A___Ischemic_Colitis.docx differ
diff --git a/data/BTK_-_27A___C_Diff.docx b/data/BTK_-_27A___C_Diff.docx
new file mode 100644
index 0000000000000000000000000000000000000000..c786dec7369e8d98ee0c9b1758dc588c5490af39
Binary files /dev/null and b/data/BTK_-_27A___C_Diff.docx differ
diff --git a/data/BTK_-_28A___Volvulus.docx b/data/BTK_-_28A___Volvulus.docx
new file mode 100644
index 0000000000000000000000000000000000000000..86db5e56b0b7cb355c3d6b7d46b9eb76309be983
Binary files /dev/null and b/data/BTK_-_28A___Volvulus.docx differ
diff --git a/data/BTK_-_29A___Lower_GI_Bleed.docx b/data/BTK_-_29A___Lower_GI_Bleed.docx
new file mode 100644
index 0000000000000000000000000000000000000000..1e4b9e83dbc9ec3782eae932edf219d61f388df4
Binary files /dev/null and b/data/BTK_-_29A___Lower_GI_Bleed.docx differ
diff --git a/data/BTK_-_2A___Ventral_Hernia.docx b/data/BTK_-_2A___Ventral_Hernia.docx
new file mode 100644
index 0000000000000000000000000000000000000000..0f4a992a4a70ab0d0712458385f628e04da46108
Binary files /dev/null and b/data/BTK_-_2A___Ventral_Hernia.docx differ
diff --git a/data/BTK_-_30A___Diverticulitis.docx b/data/BTK_-_30A___Diverticulitis.docx
new file mode 100644
index 0000000000000000000000000000000000000000..bc5626432ca4dac83d17e9b3edf6e9579847f3b3
Binary files /dev/null and b/data/BTK_-_30A___Diverticulitis.docx differ
diff --git a/data/BTK_-_31A___Acute_Appendicitis_.docx b/data/BTK_-_31A___Acute_Appendicitis_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..a54d4930b727c7e2bf597fa593dd15511a525ffd
Binary files /dev/null and b/data/BTK_-_31A___Acute_Appendicitis_.docx differ
diff --git a/data/BTK_-_32A___Rectal_Cancer.docx b/data/BTK_-_32A___Rectal_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..b8fb43e012ab29a6f20d37950ac76bbf80ae919e
Binary files /dev/null and b/data/BTK_-_32A___Rectal_Cancer.docx differ
diff --git a/data/BTK_-_33A___Anal_Cancer.docx b/data/BTK_-_33A___Anal_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2dd73a881a88175e83d7c40044fbab90512684a3
Binary files /dev/null and b/data/BTK_-_33A___Anal_Cancer.docx differ
diff --git a/data/BTK_-_34A___Anal_Fissure.docx b/data/BTK_-_34A___Anal_Fissure.docx
new file mode 100644
index 0000000000000000000000000000000000000000..d2837ccacc927f8b6995bc5f6b64a8da850b02be
Binary files /dev/null and b/data/BTK_-_34A___Anal_Fissure.docx differ
diff --git a/data/BTK_-_35A___Hemorrhoids.docx b/data/BTK_-_35A___Hemorrhoids.docx
new file mode 100644
index 0000000000000000000000000000000000000000..3708ee6ea41118a43beb1e4e4b373da43e127332
Binary files /dev/null and b/data/BTK_-_35A___Hemorrhoids.docx differ
diff --git a/data/BTK_-_36A___Perirectal_Abscess.docx b/data/BTK_-_36A___Perirectal_Abscess.docx
new file mode 100644
index 0000000000000000000000000000000000000000..7c947c24f08b8fb453a5a564533138aa2708ca9a
Binary files /dev/null and b/data/BTK_-_36A___Perirectal_Abscess.docx differ
diff --git a/data/BTK_-_37A___Colonoscopy.docx b/data/BTK_-_37A___Colonoscopy.docx
new file mode 100644
index 0000000000000000000000000000000000000000..3b475bc1be1615898cc444341e3847fa4f0b17a4
Binary files /dev/null and b/data/BTK_-_37A___Colonoscopy.docx differ
diff --git a/data/BTK_-_38A___Ductal_Carcinoma_In_Situ__.docx b/data/BTK_-_38A___Ductal_Carcinoma_In_Situ__.docx
new file mode 100644
index 0000000000000000000000000000000000000000..a153b3b575e30ea9fab3e6ea7c985190aa825393
Binary files /dev/null and b/data/BTK_-_38A___Ductal_Carcinoma_In_Situ__.docx differ
diff --git a/data/BTK_-_39A___Breast_Abscess.docx b/data/BTK_-_39A___Breast_Abscess.docx
new file mode 100644
index 0000000000000000000000000000000000000000..8640bc2eaa49834707a746d6d1117d5e6d86eb03
Binary files /dev/null and b/data/BTK_-_39A___Breast_Abscess.docx differ
diff --git a/data/BTK_-_3A___Diaphragmatic_Hiatal_Hernia.docx b/data/BTK_-_3A___Diaphragmatic_Hiatal_Hernia.docx
new file mode 100644
index 0000000000000000000000000000000000000000..1347b71af503040481521ec595059bd3c914f635
Binary files /dev/null and b/data/BTK_-_3A___Diaphragmatic_Hiatal_Hernia.docx differ
diff --git a/data/BTK_-_40A___Breast_Cancer_in_Pregnancy.docx b/data/BTK_-_40A___Breast_Cancer_in_Pregnancy.docx
new file mode 100644
index 0000000000000000000000000000000000000000..d967f66e85d48a86d94af642b6faf3ea9a872da4
Binary files /dev/null and b/data/BTK_-_40A___Breast_Cancer_in_Pregnancy.docx differ
diff --git a/data/BTK_-_41A___Breast_Mass.docx b/data/BTK_-_41A___Breast_Mass.docx
new file mode 100644
index 0000000000000000000000000000000000000000..49d34832e6d735c98cb2c88e01ece452b8a1b92c
Binary files /dev/null and b/data/BTK_-_41A___Breast_Mass.docx differ
diff --git a/data/BTK_-_42A___Inflammatory_Breast_Cancer_.docx b/data/BTK_-_42A___Inflammatory_Breast_Cancer_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..dd52c44f9a13d65bd517bdd6e967ef92baed3108
Binary files /dev/null and b/data/BTK_-_42A___Inflammatory_Breast_Cancer_.docx differ
diff --git a/data/BTK_-_43A___Intraductal_Papilloma_.docx b/data/BTK_-_43A___Intraductal_Papilloma_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2bf154b15464048f7877ee1dc8b9d655c1f47643
Binary files /dev/null and b/data/BTK_-_43A___Intraductal_Papilloma_.docx differ
diff --git a/data/BTK_-_44A___Lobular_Carcinoma_in_Situ.docx b/data/BTK_-_44A___Lobular_Carcinoma_in_Situ.docx
new file mode 100644
index 0000000000000000000000000000000000000000..9812faf730bfd3d075361a6e36270c03894a5ebf
Binary files /dev/null and b/data/BTK_-_44A___Lobular_Carcinoma_in_Situ.docx differ
diff --git a/data/BTK_-_45A__Breast_Cancer.docx b/data/BTK_-_45A__Breast_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..8a064ef4168ee9d0e1c4a797dc824c3f21ac6017
Binary files /dev/null and b/data/BTK_-_45A__Breast_Cancer.docx differ
diff --git a/data/BTK_-_46A___Incidental_Adrenal_Mass.docx b/data/BTK_-_46A___Incidental_Adrenal_Mass.docx
new file mode 100644
index 0000000000000000000000000000000000000000..acfcf2091fe79baae9357fd39921349c7ada3511
Binary files /dev/null and b/data/BTK_-_46A___Incidental_Adrenal_Mass.docx differ
diff --git a/data/BTK_-_47A___Hyperparathyroidism.docx b/data/BTK_-_47A___Hyperparathyroidism.docx
new file mode 100644
index 0000000000000000000000000000000000000000..20bcd5caec6080f20b64c5a671236004f59c36c2
Binary files /dev/null and b/data/BTK_-_47A___Hyperparathyroidism.docx differ
diff --git a/data/BTK_-_48A___Hyperthyroidism.docx b/data/BTK_-_48A___Hyperthyroidism.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2bd249802de3a29a2527521b498fda7253ee4f07
Binary files /dev/null and b/data/BTK_-_48A___Hyperthyroidism.docx differ
diff --git a/data/BTK_-_49A___Thyroid_Cancer.docx b/data/BTK_-_49A___Thyroid_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..bbc7dbb78212a0d89482309f15b55505537d8a1f
Binary files /dev/null and b/data/BTK_-_49A___Thyroid_Cancer.docx differ
diff --git a/data/BTK_-_4A___Inguinal_Femoral_Hernia.docx b/data/BTK_-_4A___Inguinal_Femoral_Hernia.docx
new file mode 100644
index 0000000000000000000000000000000000000000..f306b8ada9c365a4eec28570950e6a06cd14c148
Binary files /dev/null and b/data/BTK_-_4A___Inguinal_Femoral_Hernia.docx differ
diff --git a/data/BTK_-_50A___Thyroid_Nodule.docx b/data/BTK_-_50A___Thyroid_Nodule.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2ded620b7903570f9601d844dfdf9dd48f2f58b9
Binary files /dev/null and b/data/BTK_-_50A___Thyroid_Nodule.docx differ
diff --git a/data/BTK_-_51A___Adrenal_Cancer.docx b/data/BTK_-_51A___Adrenal_Cancer.docx
new file mode 100644
index 0000000000000000000000000000000000000000..fa094395af05804ddcb1ff03f3db8ae4d234775a
Binary files /dev/null and b/data/BTK_-_51A___Adrenal_Cancer.docx differ
diff --git a/data/BTK_-_52A___Hyperaldosteronism.docx b/data/BTK_-_52A___Hyperaldosteronism.docx
new file mode 100644
index 0000000000000000000000000000000000000000..41851b8306ae524b71f946cc978ada21cbe2bce6
Binary files /dev/null and b/data/BTK_-_52A___Hyperaldosteronism.docx differ
diff --git a/data/BTK_-_53A___Hypercortisolism.docx b/data/BTK_-_53A___Hypercortisolism.docx
new file mode 100644
index 0000000000000000000000000000000000000000..6537cd138781e9a6638e70d54b3a21b0188e9bce
Binary files /dev/null and b/data/BTK_-_53A___Hypercortisolism.docx differ
diff --git a/data/BTK_-_54A___Pheochromocytoma.docx b/data/BTK_-_54A___Pheochromocytoma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..421db5b6a217d3cc13668df80a6af6823c724726
Binary files /dev/null and b/data/BTK_-_54A___Pheochromocytoma.docx differ
diff --git a/data/BTK_-_55A___Multiple_Endocrine_Neoplasia.docx b/data/BTK_-_55A___Multiple_Endocrine_Neoplasia.docx
new file mode 100644
index 0000000000000000000000000000000000000000..24cc169c1cd3a1fc95689b2c91b336c59f199a91
Binary files /dev/null and b/data/BTK_-_55A___Multiple_Endocrine_Neoplasia.docx differ
diff --git a/data/BTK_-_56A___Melanoma.docx b/data/BTK_-_56A___Melanoma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..5f09b694dfbc97924025bb0ccd2f99184b4df112
Binary files /dev/null and b/data/BTK_-_56A___Melanoma.docx differ
diff --git a/data/BTK_-_57A___Necrotizing_Soft_Tissue_Infection.docx b/data/BTK_-_57A___Necrotizing_Soft_Tissue_Infection.docx
new file mode 100644
index 0000000000000000000000000000000000000000..8c75fba2f138b40f8a38f7feb62ff1bcf1caf65a
Binary files /dev/null and b/data/BTK_-_57A___Necrotizing_Soft_Tissue_Infection.docx differ
diff --git a/data/BTK_-_58A___Pilonidal_Cyst.docx b/data/BTK_-_58A___Pilonidal_Cyst.docx
new file mode 100644
index 0000000000000000000000000000000000000000..16ce53647a3c677119ebb4dad0bbcd6031778e2e
Binary files /dev/null and b/data/BTK_-_58A___Pilonidal_Cyst.docx differ
diff --git a/data/BTK_-_59A___Cardiac_Arrest.docx b/data/BTK_-_59A___Cardiac_Arrest.docx
new file mode 100644
index 0000000000000000000000000000000000000000..c8e3d0be269bdff52745a9f4f95f7eec6c8ba90b
Binary files /dev/null and b/data/BTK_-_59A___Cardiac_Arrest.docx differ
diff --git a/data/BTK_-_5A___Umbilical_Hernia___Cirrhotic.docx b/data/BTK_-_5A___Umbilical_Hernia___Cirrhotic.docx
new file mode 100644
index 0000000000000000000000000000000000000000..9ae088d79994f12eb84846a7542a1a59592755c3
Binary files /dev/null and b/data/BTK_-_5A___Umbilical_Hernia___Cirrhotic.docx differ
diff --git a/data/BTK_-_60A___Cardiac_Arrhythmia.docx b/data/BTK_-_60A___Cardiac_Arrhythmia.docx
new file mode 100644
index 0000000000000000000000000000000000000000..c8b61274ce0290a2d0184bda60dfced99f86efd4
Binary files /dev/null and b/data/BTK_-_60A___Cardiac_Arrhythmia.docx differ
diff --git a/data/BTK_-_61A___Liver_Failure.docx b/data/BTK_-_61A___Liver_Failure.docx
new file mode 100644
index 0000000000000000000000000000000000000000..ebe84741bd47059094d8c171d0ae17b3f7733c92
Binary files /dev/null and b/data/BTK_-_61A___Liver_Failure.docx differ
diff --git a/data/BTK_-_62A___Neurogenic_Shock_and_Delerium.docx b/data/BTK_-_62A___Neurogenic_Shock_and_Delerium.docx
new file mode 100644
index 0000000000000000000000000000000000000000..ca0fa77a0007baf1f05b9f7bd76cf4752264f265
Binary files /dev/null and b/data/BTK_-_62A___Neurogenic_Shock_and_Delerium.docx differ
diff --git a/data/BTK_-_63A___Renal_Failure.docx b/data/BTK_-_63A___Renal_Failure.docx
new file mode 100644
index 0000000000000000000000000000000000000000..a2fb5d7ba4f18ad552d74227932f5d1b2b150bb6
Binary files /dev/null and b/data/BTK_-_63A___Renal_Failure.docx differ
diff --git a/data/BTK_-_64A___Sepsis.docx b/data/BTK_-_64A___Sepsis.docx
new file mode 100644
index 0000000000000000000000000000000000000000..dc540d82ec692323d152dc1d7907604b2078e275
Binary files /dev/null and b/data/BTK_-_64A___Sepsis.docx differ
diff --git a/data/BTK_-_65A___ARDS.docx b/data/BTK_-_65A___ARDS.docx
new file mode 100644
index 0000000000000000000000000000000000000000..807f7d42ce46f83c97bbc29c6a85e88bfcc49ec9
Binary files /dev/null and b/data/BTK_-_65A___ARDS.docx differ
diff --git a/data/BTK_-_66A___Vascular_Neck_Injury.docx b/data/BTK_-_66A___Vascular_Neck_Injury.docx
new file mode 100644
index 0000000000000000000000000000000000000000..91a2f3d4bcf555c03c0637ed864a9b850f4ec660
Binary files /dev/null and b/data/BTK_-_66A___Vascular_Neck_Injury.docx differ
diff --git a/data/BTK_-_67A___Vascular_Abdomen_Injury.docx b/data/BTK_-_67A___Vascular_Abdomen_Injury.docx
new file mode 100644
index 0000000000000000000000000000000000000000..60b13c540a7669e4dab6fa79b55da5f1945c5609
Binary files /dev/null and b/data/BTK_-_67A___Vascular_Abdomen_Injury.docx differ
diff --git a/data/BTK_-_68A___Lower_Extremity_Vascular_Injury.docx b/data/BTK_-_68A___Lower_Extremity_Vascular_Injury.docx
new file mode 100644
index 0000000000000000000000000000000000000000..99d984a93aa7eea9135806a91d9a1032b50ee4e6
Binary files /dev/null and b/data/BTK_-_68A___Lower_Extremity_Vascular_Injury.docx differ
diff --git a/data/BTK_-_69A___Cardiac_Trauma.docx b/data/BTK_-_69A___Cardiac_Trauma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..9ef30a270f422a57a5e2db0803e3548007649fc3
Binary files /dev/null and b/data/BTK_-_69A___Cardiac_Trauma.docx differ
diff --git a/data/BTK_-_6A___Choledochal_Cyst.docx b/data/BTK_-_6A___Choledochal_Cyst.docx
new file mode 100644
index 0000000000000000000000000000000000000000..c60665d3b8686b6df3de0811f62e6372f8be4db6
Binary files /dev/null and b/data/BTK_-_6A___Choledochal_Cyst.docx differ
diff --git a/data/BTK_-_70A___Duodenum_and_Pancreas_Trauma.docx b/data/BTK_-_70A___Duodenum_and_Pancreas_Trauma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..7df98a01c85cad269cb77c17c3cbc3be034d6e9c
Binary files /dev/null and b/data/BTK_-_70A___Duodenum_and_Pancreas_Trauma.docx differ
diff --git a/data/BTK_-_71A___Esophagus_and_Trachea_Trauma_Empyema.docx b/data/BTK_-_71A___Esophagus_and_Trachea_Trauma_Empyema.docx
new file mode 100644
index 0000000000000000000000000000000000000000..17b55d0e47ee7bb1657bc2743e663f5cf6230364
Binary files /dev/null and b/data/BTK_-_71A___Esophagus_and_Trachea_Trauma_Empyema.docx differ
diff --git a/data/BTK_-_72A___Liver_Trauma.docx b/data/BTK_-_72A___Liver_Trauma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..8b794f7aa6ed26d17b3a0fea6a471c7c84331d01
Binary files /dev/null and b/data/BTK_-_72A___Liver_Trauma.docx differ
diff --git a/data/BTK_-_73A___Pelvic_Trauma_Pneumothorax.docx b/data/BTK_-_73A___Pelvic_Trauma_Pneumothorax.docx
new file mode 100644
index 0000000000000000000000000000000000000000..37f76242413bf648630914ee25e9c1cb436f8b21
Binary files /dev/null and b/data/BTK_-_73A___Pelvic_Trauma_Pneumothorax.docx differ
diff --git a/data/BTK_-_74A___Pregnant_Trauma.docx b/data/BTK_-_74A___Pregnant_Trauma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..ec5030c5b4d9cfc70e5affbdb4289375b0b3c5e6
Binary files /dev/null and b/data/BTK_-_74A___Pregnant_Trauma.docx differ
diff --git a/data/BTK_-_75A___Rectal_Bladder_Urethral_and_Kidney_Trauma.docx b/data/BTK_-_75A___Rectal_Bladder_Urethral_and_Kidney_Trauma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..afc0f5fa1d570aae8aef1669c6015f08c819ae1f
Binary files /dev/null and b/data/BTK_-_75A___Rectal_Bladder_Urethral_and_Kidney_Trauma.docx differ
diff --git a/data/BTK_-_76A___Spleen_Trauma_Geriatric_Trauma_Brain_Death.docx b/data/BTK_-_76A___Spleen_Trauma_Geriatric_Trauma_Brain_Death.docx
new file mode 100644
index 0000000000000000000000000000000000000000..b078eba5f3cb2a416114c2cc862ad3f9b23f4212
Binary files /dev/null and b/data/BTK_-_76A___Spleen_Trauma_Geriatric_Trauma_Brain_Death.docx differ
diff --git a/data/BTK_-_77A___Burn.docx b/data/BTK_-_77A___Burn.docx
new file mode 100644
index 0000000000000000000000000000000000000000..cce49f0af42edb264c896f625259def93e3fc13c
Binary files /dev/null and b/data/BTK_-_77A___Burn.docx differ
diff --git a/data/BTK_-_78A___Acute_Limb_Ischemia_.docx b/data/BTK_-_78A___Acute_Limb_Ischemia_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..d34ea176e504887135bc1fad0fe0a585235e4f7d
Binary files /dev/null and b/data/BTK_-_78A___Acute_Limb_Ischemia_.docx differ
diff --git a/data/BTK_-_79A___Carotid_Stenosis_.docx b/data/BTK_-_79A___Carotid_Stenosis_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..a7f35e699b0d69ac6429bcc56a6e788a541a4564
Binary files /dev/null and b/data/BTK_-_79A___Carotid_Stenosis_.docx differ
diff --git a/data/BTK_-_7A___Malignant_Biliary.docx b/data/BTK_-_7A___Malignant_Biliary.docx
new file mode 100644
index 0000000000000000000000000000000000000000..1c70d9a906a41222b365155e4f00a1d0b7a586a6
Binary files /dev/null and b/data/BTK_-_7A___Malignant_Biliary.docx differ
diff --git a/data/BTK_-_80A___Chronic_Limb_Threatening_Ischemia.docx b/data/BTK_-_80A___Chronic_Limb_Threatening_Ischemia.docx
new file mode 100644
index 0000000000000000000000000000000000000000..b4b7b646795e9e081d276f14fa328b73e5e53543
Binary files /dev/null and b/data/BTK_-_80A___Chronic_Limb_Threatening_Ischemia.docx differ
diff --git a/data/BTK_-_81A___Chronic_Venous_Insufficiency.docx b/data/BTK_-_81A___Chronic_Venous_Insufficiency.docx
new file mode 100644
index 0000000000000000000000000000000000000000..513792ca94d12e110e3c8f99875a59cf49a457e8
Binary files /dev/null and b/data/BTK_-_81A___Chronic_Venous_Insufficiency.docx differ
diff --git a/data/BTK_-_83A___Vascular_Access_.docx b/data/BTK_-_83A___Vascular_Access_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..9d53670ef154ff36b5b65fa0455747a5dd2b9a80
Binary files /dev/null and b/data/BTK_-_83A___Vascular_Access_.docx differ
diff --git a/data/BTK_-_84A___Abdominal_Mass_Pediatrics.docx b/data/BTK_-_84A___Abdominal_Mass_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..9340786b3cf0651b85df7dd2324ea9b2ec0f2c75
Binary files /dev/null and b/data/BTK_-_84A___Abdominal_Mass_Pediatrics.docx differ
diff --git a/data/BTK_-_85A___Appendicitis_Pediatrics.docx b/data/BTK_-_85A___Appendicitis_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..8157833cda6e4ec67f9255afcfea75c7c32aa65f
Binary files /dev/null and b/data/BTK_-_85A___Appendicitis_Pediatrics.docx differ
diff --git a/data/BTK_-_86A___Inguinal_Hernia_Pediatrics.docx b/data/BTK_-_86A___Inguinal_Hernia_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..02d16e69b5e92d2ef04314b0ff564fb880d723f1
Binary files /dev/null and b/data/BTK_-_86A___Inguinal_Hernia_Pediatrics.docx differ
diff --git a/data/BTK_-_87A___Intussusception_Pediatrics.docx b/data/BTK_-_87A___Intussusception_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..01adad7a29680d32bdf7259bdff2b017ec1afe68
Binary files /dev/null and b/data/BTK_-_87A___Intussusception_Pediatrics.docx differ
diff --git a/data/BTK_-_88A___Malrotation_Pediatrics.docx b/data/BTK_-_88A___Malrotation_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..f0034d9a96894965924b749c2947101e96504e0d
Binary files /dev/null and b/data/BTK_-_88A___Malrotation_Pediatrics.docx differ
diff --git a/data/BTK_-_89A___Meckel_s_Diverticulum_Pediatrics.docx b/data/BTK_-_89A___Meckel_s_Diverticulum_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..ca216269f11c447fb781ccd59005194e09d1a81d
Binary files /dev/null and b/data/BTK_-_89A___Meckel_s_Diverticulum_Pediatrics.docx differ
diff --git a/data/BTK_-_8A___Benign_Biliary.docx b/data/BTK_-_8A___Benign_Biliary.docx
new file mode 100644
index 0000000000000000000000000000000000000000..1ff3c94172f7a3aa09153dd73add00168f2446e3
Binary files /dev/null and b/data/BTK_-_8A___Benign_Biliary.docx differ
diff --git a/data/BTK_-_90A___Pyloric_Stenosis_Pediatrics.docx b/data/BTK_-_90A___Pyloric_Stenosis_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..2a9891b18aa0d545cc175c73f604c327c584384e
Binary files /dev/null and b/data/BTK_-_90A___Pyloric_Stenosis_Pediatrics.docx differ
diff --git a/data/BTK_-_91A___Umbilical_Hernia_Pediatrics.docx b/data/BTK_-_91A___Umbilical_Hernia_Pediatrics.docx
new file mode 100644
index 0000000000000000000000000000000000000000..708de74ca62219976e6bf1530cfa5d5bf975089e
Binary files /dev/null and b/data/BTK_-_91A___Umbilical_Hernia_Pediatrics.docx differ
diff --git a/data/BTK_-_92A___Ectopic_Pregnancy_.docx b/data/BTK_-_92A___Ectopic_Pregnancy_.docx
new file mode 100644
index 0000000000000000000000000000000000000000..f03cc426bcd7935b8a03e776f893390dd66a5aba
Binary files /dev/null and b/data/BTK_-_92A___Ectopic_Pregnancy_.docx differ
diff --git a/data/BTK_-_93A___Ruptured_AAA.docx b/data/BTK_-_93A___Ruptured_AAA.docx
new file mode 100644
index 0000000000000000000000000000000000000000..316d219152a1455d5ed42d32f980b8a53f3b3d69
Binary files /dev/null and b/data/BTK_-_93A___Ruptured_AAA.docx differ
diff --git a/data/BTK_-_94A___Acute_Aortic_Dissection.docx b/data/BTK_-_94A___Acute_Aortic_Dissection.docx
new file mode 100644
index 0000000000000000000000000000000000000000..7fdbcaf86858b66a744cd346768c2ff951178886
Binary files /dev/null and b/data/BTK_-_94A___Acute_Aortic_Dissection.docx differ
diff --git a/data/BTK_-_95A___Cancer_in_a_Polyp.docx b/data/BTK_-_95A___Cancer_in_a_Polyp.docx
new file mode 100644
index 0000000000000000000000000000000000000000..54ce8059bad9f2f72c6f1c4d7a56f0608ff851a7
Binary files /dev/null and b/data/BTK_-_95A___Cancer_in_a_Polyp.docx differ
diff --git a/data/BTK_-_96A___Anastomotic_Leak.docx b/data/BTK_-_96A___Anastomotic_Leak.docx
new file mode 100644
index 0000000000000000000000000000000000000000..304a4b78d55b8950f949661c799b6aa37c32beeb
Binary files /dev/null and b/data/BTK_-_96A___Anastomotic_Leak.docx differ
diff --git a/data/BTK_-_97A___Merkel_Cell_Carcinoma.docx b/data/BTK_-_97A___Merkel_Cell_Carcinoma.docx
new file mode 100644
index 0000000000000000000000000000000000000000..f017c26e7b4a5a6a0e6c195c368f9acaa0f5ed5d
Binary files /dev/null and b/data/BTK_-_97A___Merkel_Cell_Carcinoma.docx differ
diff --git a/data/BTK_-_98A___Zenkers_Diverticulum.docx b/data/BTK_-_98A___Zenkers_Diverticulum.docx
new file mode 100644
index 0000000000000000000000000000000000000000..bf154d669c4210af8ea1f4fad74756200885dc3b
Binary files /dev/null and b/data/BTK_-_98A___Zenkers_Diverticulum.docx differ
diff --git a/data/BTK_-_99A___Enteric_Feeding_Tube_Troubles.docx b/data/BTK_-_99A___Enteric_Feeding_Tube_Troubles.docx
new file mode 100644
index 0000000000000000000000000000000000000000..73a01e782474b19d8364ad6c43d180ba6eddbdd8
Binary files /dev/null and b/data/BTK_-_99A___Enteric_Feeding_Tube_Troubles.docx differ
diff --git a/data/BTK_-_9A___Hepatic_Neoplasms.docx b/data/BTK_-_9A___Hepatic_Neoplasms.docx
new file mode 100644
index 0000000000000000000000000000000000000000..576737fd7e4e1d5018faa87711fae6d8185be7a1
Binary files /dev/null and b/data/BTK_-_9A___Hepatic_Neoplasms.docx differ
diff --git a/notebooks/.ipynb_checkpoints/demo-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/demo-checkpoint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..eb3685c1080f69a464a13d40ee436676d4a3cfd1
--- /dev/null
+++ b/notebooks/.ipynb_checkpoints/demo-checkpoint.ipynb
@@ -0,0 +1,853 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "intro-markdown",
+ "metadata": {},
+ "source": [
+ "# RAG Pipeline Demo: Surgery Oral Board Simulation\n",
+ "\n",
+ "This notebook demonstrates the usage of the RAG (Retrieval-Augmented Generation) pipeline developed for simulating surgery oral board scenarios. It utilizes the custom Python modules located in the `src/` directory.\n",
+ "\n",
+ "**Pipeline Components:**\n",
+ "* `data_processing`: Loads and preprocesses raw case data from `.docx` files.\n",
+ "* `ClinicalCaseProcessor`: Creates embeddings for clinical cases.\n",
+ "* `ClinicalCaseRetriever`: Retrieves relevant cases based on semantic similarity to a query.\n",
+ "* `AnswerEvaluator`: Uses an LLM to evaluate user responses against expected answers.\n",
+ "* `OralExamSimulator`: Orchestrates the simulation flow."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "setup-markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Setup\n",
+ "\n",
+ "Import necessary libraries and custom modules. Define constants and handle Hugging Face Hub authentication."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "setup-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-04-22 12:56:15.809380: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
+ "E0000 00:00:1745340976.161054 614274 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "E0000 00:00:1745340976.276502 614274 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "W0000 00:00:1745340977.166337 614274 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1745340977.166411 614274 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1745340977.166413 614274 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1745340977.166414 614274 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "2025-04-22 12:56:17.174186: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import sys \n",
+ "import re\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from dotenv import load_dotenv\n",
+ "from huggingface_hub import login\n",
+ "from sklearn.metrics import ndcg_score\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "from IPython.display import display, Markdown\n",
+ "\n",
+ "# --- Add project root to sys.path ---\n",
+ "project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))\n",
+ "if project_root not in sys.path:\n",
+ " sys.path.append(project_root)\n",
+ "\n",
+ "# --- Custom Module Imports ---\n",
+ "from src.data_processing import process_all_cases, ClinicalCaseProcessor\n",
+ "from src.retriever import ClinicalCaseRetriever, DummyRetriever\n",
+ "from src.evaluator import AnswerEvaluator\n",
+ "from src.simulator import OralExamSimulator\n",
+ "from src.evaluation_utils import retrieval_metrics\n",
+ "from src.synthetic_generator import generate_synthetic_case, process_synthetic_data\n",
+ "\n",
+ "# --- Configuration ---\n",
+ "DATA_FOLDER = \"../data/\"\n",
+ "PROCESSED_DATA_PATH = \"../processed_clinical_cases\"\n",
+ "EVALUATOR_MODEL_ID = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
+ "EMBEDDING_MODEL_ID = \"all-MiniLM-L6-v2\"\n",
+ "\n",
+ "# Hugging Face Hub Authentication (using .env file in project root)\n",
+ "dotenv_path = os.path.join(project_root, 'hf_login.env')\n",
+ "load_dotenv(dotenv_path=dotenv_path)\n",
+ "hf_key = os.getenv(\"HF_KEY\")\n",
+ "\n",
+ "# --- Set Cache Directory (Optional) ---\n",
+ "cache_dir = \"/scratch/krb3ym/models/cache\" # Keep your specific path\n",
+ "os.environ['HF_HOME'] = cache_dir\n",
+ "os.makedirs(cache_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "load-data-markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Load and Preprocess Raw Data\n",
+ "\n",
+ "Load the clinical cases from the `.docx` files into a pandas DataFrame using the utility function from `src.data_processing`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "load-data-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing case files from: ../data/\n",
+ "Warning: Could not parse filename format: BTK_-_45A__Breast_Cancer.docx\n",
+ "Warning: Could not parse filename format: BTK_10A___Pancreatic_Neoplasm.docx\n",
+ "Loaded 98 cases. Displaying head:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ "    .dataframe tbody tr th:only-of-type {\n",
+ "        vertical-align: middle;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe tbody tr th {\n",
+ "        vertical-align: top;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe thead th {\n",
+ "        text-align: right;\n",
+ "    }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>case_id</th>\n",
+ "      <th>clinical_presentation</th>\n",
+ "      <th>turn_id</th>\n",
+ "      <th>question</th>\n",
+ "      <th>answer</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>0</th>\n",
+ "      <td>85A</td>\n",
+ "      <td>Appendicitis Pediatrics</td>\n",
+ "      <td>1</td>\n",
+ "      <td>You were called to the emergency department to...</td>\n",
+ "      <td>Okay, well, I would begin by evaluating the pa...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1</th>\n",
+ "      <td>85A</td>\n",
+ "      <td>Appendicitis Pediatrics</td>\n",
+ "      <td>2</td>\n",
+ "      <td>The patient is febrile at 101.2 degrees, heart...</td>\n",
+ "      <td>Okay, does he have a positive psoas or a Rovsi...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>2</th>\n",
+ "      <td>85A</td>\n",
+ "      <td>Appendicitis Pediatrics</td>\n",
+ "      <td>3</td>\n",
+ "      <td>Yes, both are positive.</td>\n",
+ "      <td>I'm concerned about acute appendicitis, but th...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>3</th>\n",
+ "      <td>85A</td>\n",
+ "      <td>Appendicitis Pediatrics</td>\n",
+ "      <td>4</td>\n",
+ "      <td>Labs are in order. White blood cell count of 1...</td>\n",
+ "      <td>Well, this is diagnostic of acute appendicitis...</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>4</th>\n",
+ "      <td>85A</td>\n",
+ "      <td>Appendicitis Pediatrics</td>\n",
+ "      <td>5</td>\n",
+ "      <td>Are there any other therapeutic options for th...</td>\n",
+ "      <td>Yes, the patient has early acute appendicitis ...</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " case_id clinical_presentation turn_id \\\n",
+ "0 85A Appendicitis Pediatrics 1 \n",
+ "1 85A Appendicitis Pediatrics 2 \n",
+ "2 85A Appendicitis Pediatrics 3 \n",
+ "3 85A Appendicitis Pediatrics 4 \n",
+ "4 85A Appendicitis Pediatrics 5 \n",
+ "\n",
+ " question \\\n",
+ "0 You were called to the emergency department to... \n",
+ "1 The patient is febrile at 101.2 degrees, heart... \n",
+ "2 Yes, both are positive. \n",
+ "3 Labs are in order. White blood cell count of 1... \n",
+ "4 Are there any other therapeutic options for th... \n",
+ "\n",
+ " answer \n",
+ "0 Okay, well, I would begin by evaluating the pa... \n",
+ "1 Okay, does he have a positive psoas or a Rovsi... \n",
+ "2 I'm concerned about acute appendicitis, but th... \n",
+ "3 Well, this is diagnostic of acute appendicitis... \n",
+ "4 Yes, the patient has early acute appendicitis ... "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "dataframe_btk = process_all_cases(DATA_FOLDER)\n",
+ "print(f\"Loaded {dataframe_btk['case_id'].nunique()} cases. Displaying head:\")\n",
+ "display(dataframe_btk.head())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "prepare-rag-markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Prepare RAG Dataset with Embeddings\n",
+ "\n",
+ "Use the `ClinicalCaseProcessor` to process the DataFrame, generate embeddings for each case summary, and save the result as a Hugging Face Dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "prepare-rag-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Initializing ClinicalCaseProcessor with model: all-MiniLM-L6-v2\n",
+ "Using device: cuda\n",
+ "Using provided DataFrame.\n",
+ "Raw data shape: (973, 5)\n",
+ "Grouping data by case...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing Cases: 100%|██████████| 98/98 [00:00<00:00, 6289.76it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processed data into 98 unique cases.\n",
+ "Generating embeddings for 98 case summaries...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Embedding Batches: 100%|██████████| 7/7 [00:00<00:00, 11.43it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated embeddings with shape: (98, 384)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "30e6d7f6e941472a96f0e3a8532cd63d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/1 shards): 0%| | 0/98 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processed dataset saved successfully to ../processed_clinical_cases\n",
+ "Dataset({\n",
+ " features: ['case_id', 'clinical_presentation', 'questions', 'answers', 'case_summary', 'embeddings'],\n",
+ " num_rows: 98\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "processor = ClinicalCaseProcessor(model_name=EMBEDDING_MODEL_ID)\n",
+ "processed_dataset = processor.preprocess_data(dataframe_btk, output_path=PROCESSED_DATA_PATH)\n",
+ "print(processed_dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "init-pipeline-markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Initialize RAG Pipeline Components\n",
+ "\n",
+ "Instantiate the retriever, evaluator, and simulator classes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "init-pipeline-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Initializing ClinicalCaseRetriever with model: all-MiniLM-L6-v2\n",
+ "Dataset loaded successfully from disk: ../processed_clinical_cases\n",
+ "Dataset features: {'case_id': Value(dtype='string', id=None), 'clinical_presentation': Value(dtype='string', id=None), 'questions': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'answers': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'case_summary': Value(dtype='string', id=None), 'embeddings': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}\n",
+ "Number of cases in dataset: 98\n",
+ "Using device: cuda\n",
+ "Loaded 98 cases with embeddings of shape (98, 384)\n",
+ "Initializing AnswerEvaluator with model: meta-llama/Llama-3.2-3B-Instruct\n",
+ "Set pad_token to eos_token\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8e6040279e1849e8b55c82ad0a07ecaa",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AnswerEvaluator model loaded successfully on device: cuda:0\n"
+ ]
+ }
+ ],
+ "source": [
+ "retriever = ClinicalCaseRetriever(dataset_path=PROCESSED_DATA_PATH, model_name=EMBEDDING_MODEL_ID)\n",
+ "evaluator = AnswerEvaluator(model_id=EVALUATOR_MODEL_ID)\n",
+ "simulator = OralExamSimulator(retriever, evaluator)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Run Simulation with RAG\n",
+ "\n",
+ "Demonstrate the core simulation loop: start a case based on a query, process user responses, and get feedback. \n",
+ "Let's take the example of 'bowel intussusception in a child'."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-start-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.1 Start a New Case"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "run-sim-start-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Please enter the clinical scenario that you would like to be examined on:**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " Your Answer: bowel intussusception in child\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--------------------------------------------------\n",
+ "Attempting to start new case | Query: 'bowel intussusception in child' | Index: None\n",
+ "Encoding query: 'Clinical case about bowel intussusception in child'\n",
+ "Retrieved 1 cases with similarity scores:\n",
+ "- Intussusception Pediatrics: 0.7762\n",
+ "Retrieved case via query ('bowel intussusception in child') with score 0.7762: Intussusception Pediatrics\n",
+ "Case successfully started. Total questions: 12\n",
+ "--------------------------------------------------\n",
+ "\n",
+ "Case Started: Intussusception Pediatrics\n",
+ "Total Questions: 12\n",
+ "\n",
+ "--- Question 1 --- \n",
+ "You're called to the emergency department to evaluate an 18-month-old boy with an eight-hour history of intermittent intense abdominal pain. His mother notes that the pain starts abruptly, and that the kid will draw his knees up to his chest during these painful episodes. Pain lasts for 30 minutes at a time and spontaneously resolves. He's had one episode of non-bloody emesis. He's not had diarrhea or bloody stools.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# --- Start Simulation ---\n",
+ "display(Markdown(f\"**➡️ Please enter the clinical scenario that you would like to be examined on:**\"))\n",
+ "query = input(\" Your Answer: \")\n",
+ "case_info = simulator.start_new_case(clinical_query=query)\n",
+ "\n",
+ "# Directly print the first question (assuming success)\n",
+ "print(f\"\\nCase Started: {case_info.get('clinical_presentation', 'N/A')}\")\n",
+ "print(f\"Total Questions: {case_info.get('total_questions', 'N/A')}\")\n",
+ "print(\"\\n--- Question 1 --- \")\n",
+ "print(case_info.get('current_question', 'Error: No question found'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-turn1-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.2 Process User Response (Turn 1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "44c6e822-5508-4db2-8700-79f25dc14e05",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Your Turn (Question 1/12)**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " Your Answer: Perform physical exam, looking specifically for abdominal signs, order labs, and get vitals\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "⏳ Processing User Answer...\n",
+ "--------------------------------------------------\n",
+ "Processing response for Question 1/12\n",
+ "User Response: Perform physical exam, looking specifically for abdominal signs, order labs, and get vitals\n",
+ "Expected Answer: I'd begin by obtaining vital signs and performing a comprehensive history and physical examination, focusing on his history, birth history, and my exam focusing on his abdominal exam, looking for hernias, and doing a rectal exam as well.\n",
+ "Generated Feedback: ## Step 1: Evaluate the resident's response against the expected answer.\n",
+ "The resident's response is: \"Perform physical exam, looking specifically for abdominal signs, order labs, and get vitals.\" This response is incomplete as it does not mention obtaining a comprehensive history and physical examination, focusing on the patient's history, birth history, and abdominal exam, nor does it mention a rectal exam.\n",
+ "\n",
+ "## Step 2: Identify the key points that were mentioned vs. missed.\n",
+ "The resident's response missed the key points of obtaining a comprehensive history and physical examination, focusing on the patient's history, birth history, and abdominal exam, and performing a rectal exam.\n",
+ "\n",
+ "## Step 3: Evaluate the accuracy and clarity of the clinical reasoning.\n",
+ "Next question (1/12): The patient's tachycardic, the rest of the vital signs are normal. Your exam reveals a toddler in the fetal position on the stretcher. He appears to be in moderate distress. He has diffuse abdominal tenderness and voluntary guarding.\n",
+ "--------------------------------------------------\n",
+ "\n",
+ "❓Next Question (1/12)\n",
+ "\n",
+ "The patient's tachycardic, the rest of the vital signs are normal. Your exam reveals a toddler in the fetal position on the stretcher. He appears to be in moderate distress. He has diffuse abdominal tenderness and voluntary guarding.\n"
+ ]
+ }
+ ],
+ "source": [
+ "display(Markdown(f\"**➡️ Your Turn (Question {simulator.current_question_idx + 1}/{len(simulator.current_case['questions'])})**\"))\n",
+ "user_answer = input(\" Your Answer: \")\n",
+ "\n",
+ "# --- Process Response ---\n",
+ "print(f\"\\n⏳ Processing User Answer...\")\n",
+ "result = simulator.process_user_response(user_answer)\n",
+ "\n",
+ "# Next question\n",
+ "q_num = result.get('question_number', '?')\n",
+ "next_q_text = result.get('next_question', '*Error: No next question found*')\n",
+ "print(f\"\\n❓Next Question ({q_num}/{len(simulator.current_case['questions'])})\\n\\n{next_q_text}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-turn2-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.3 Process User Response (Turn 2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "run-sim-turn2-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Your Turn (Question 2/12)**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " Your Answer: Ask for labs, get imaging\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "⏳ Processing User Answer...\n",
+ "--------------------------------------------------\n",
+ "Processing response for Question 2/12\n",
+ "User Response: Ask for labs, get imaging\n",
+ "Expected Answer: I'd like to know a little bit more about the history. Has he recently been ill or had any sick contacts? When did he pass stool or gas last? Is this a first such episode? Is there any family history of GI conditions or malignancy?\n",
+ "Generated Feedback: This is the prompt for the oral board exam question. The resident’s response is:\n",
+ "\n",
+ "I'd like to know a little bit more about the history. Has he recently been ill or had any sick contacts? When did he pass stool or gas last? Is this a first such episode? Is there any family history of GI conditions or malignancy?\n",
+ "\n",
+ "The expected answer is:\n",
+ "\n",
+ "Ask for labs, get imaging\n",
+ "\n",
+ "The resident’s response is:\n",
+ "\n",
+ "I'd like to know a little bit more about the history. Has he recently been ill or had any sick contacts? When did he pass stool or gas last? Is this a first such episode? Is there any family history of GI conditions or malignancy?\n",
+ "\n",
+ "The resident’s response is identical to the expected answer\n",
+ "Next question (2/12): He has had a mild upper respiratory infection for the past week, but he's been eating normally. His last bowel movement was yesterday. This is the first time he's ever had symptoms like this, and his family and personal history is otherwise unremarkable.\n",
+ "--------------------------------------------------\n",
+ "\n",
+ "**❓ Next Question (2/12)**\n",
+ "\n",
+ "He has had a mild upper respiratory infection for the past week, but he's been eating normally. His last bowel movement was yesterday. This is the first time he's ever had symptoms like this, and his family and personal history is otherwise unremarkable.\n"
+ ]
+ }
+ ],
+ "source": [
+ "display(Markdown(f\"**➡️ Your Turn (Question {simulator.current_question_idx + 1}/{len(simulator.current_case['questions'])})**\"))\n",
+ "user_answer = input(\" Your Answer: \")\n",
+ "\n",
+ "print(f\"\\n⏳ Processing User Answer...\")\n",
+ "result = simulator.process_user_response(user_answer)\n",
+ "\n",
+ "q_num = result.get('question_number', '?')\n",
+ "next_q_text = result.get('next_question', '*Error: No next question found*')\n",
+ "print(f\"\\n❓ Next Question ({q_num}/{len(simulator.current_case['questions'])})\\n\\n{next_q_text}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-summary-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.4 Session Summary and Saving"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "run-sim-summary-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "--- Generating Session Summary --- \n",
+ "Case: Intussusception Pediatrics\n",
+ "Case ID: 87A\n",
+ "Total Questions in Case: 12\n",
+ "Number of Interactions Logged: 8\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"\\n--- Generating Session Summary --- \")\n",
+ "session_summary = simulator.generate_session_summary()\n",
+ "\n",
+ "print(f\"Case: {session_summary.get('case', 'N/A')}\")\n",
+ "print(f\"Case ID: {session_summary.get('case_id', 'N/A')}\")\n",
+ "print(f\"Total Questions in Case: {session_summary.get('total_questions_in_case', 'N/A')}\")\n",
+ "print(f\"Number of Interactions Logged: {len(session_summary.get('interaction_history', []))}\")\n",
+ "\n",
+ "# Display full session interaction\n",
+ "#print(json.dumps(session_summary, indent=2)) # Keep commented out or remove"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "eval-retrieval-markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Evaluate Retrieval Performance\n",
+ "\n",
+ "Creating dummy dataset to evaluate RAG performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "eval-retrieval-run-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Calculating retrieval metrics for 5 queries (k=5)...\n",
+ "\n",
+ "Processing query 1/5: 'appendix inflammation in a child' (Expected ID: '85A')\n",
+ "Encoding query: 'Clinical case about appendix inflammation in a child'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Appendicitis Pediatrics: 0.7076\n",
+ "- Acute Appendicitis: 0.5978\n",
+ "- Intussusception Pediatrics: 0.5745\n",
+ "- Meckel s Diverticulum Pediatrics: 0.5277\n",
+ "- Abdominal Mass Pediatrics: 0.4705\n",
+ "Retrieved IDs: ['85A', '31A', '87A', '89A', '84A']\n",
+ "Retrieved Scores: [0.7076, 0.5978, 0.5745, 0.5277, 0.4705]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 2/5: 'Pyloric Stenosis in a child' (Expected ID: '90A')\n",
+ "Encoding query: 'Clinical case about Pyloric Stenosis in a child'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Pyloric Stenosis Pediatrics: 0.7554\n",
+ "- Carotid Stenosis: 0.4458\n",
+ "- Pheochromocytoma: 0.4139\n",
+ "- Small Bowel Obstruction: 0.3876\n",
+ "- Intussusception Pediatrics: 0.3855\n",
+ "Retrieved IDs: ['90A', '79A', '54A', '23A', '87A']\n",
+ "Retrieved Scores: [0.7554, 0.4458, 0.4139, 0.3876, 0.3855]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 3/5: 'perforation of the esophagus' (Expected ID: '16A')\n",
+ "Encoding query: 'Clinical case about perforation of the esophagus'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Esophageal Perforation: 0.6348\n",
+ "- Esophageal Cancer: 0.5415\n",
+ "- Esophageal Dysmotility: 0.5269\n",
+ "- ARDS: 0.4708\n",
+ "- Esophagus and Trachea Trauma Empyema: 0.4500\n",
+ "Retrieved IDs: ['16A', '14A', '15A', '65A', '71A']\n",
+ "Retrieved Scores: [0.6348, 0.5415, 0.5269, 0.4708, 0.45]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 4/5: 'papilloma of the breast' (Expected ID: '43A')\n",
+ "Encoding query: 'Clinical case about papilloma of the breast'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Intraductal Papilloma: 0.7125\n",
+ "- Unknown: 0.4971\n",
+ "- Inflammatory Breast Cancer: 0.4842\n",
+ "- Breast Abscess: 0.4826\n",
+ "- Breast Cancer in Pregnancy: 0.4616\n",
+ "Retrieved IDs: ['43A', 'Unknown', '42A', '39A', '40A']\n",
+ "Retrieved Scores: [0.7125, 0.4971, 0.4842, 0.4826, 0.4616]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 5/5: 'injury to the neck vessel' (Expected ID: '66A')\n",
+ "Encoding query: 'Clinical case about injury to the neck vessel'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Vascular Neck Injury: 0.6604\n",
+ "- Esophagus and Trachea Trauma Empyema: 0.4970\n",
+ "- Liver Trauma: 0.4957\n",
+ "- Carotid Stenosis: 0.4798\n",
+ "- Lower Extremity Vascular Injury: 0.4598\n",
+ "Retrieved IDs: ['66A', '71A', '72A', '79A', '68A']\n",
+ "Retrieved Scores: [0.6604, 0.497, 0.4957, 0.4798, 0.4598]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "--- Overall Retrieval Results (k=5) --- \n",
+ "Average Hit@5: 1.0000\n",
+ "Average MRR: 1.0000\n",
+ "Average NDCG@5: 1.0000\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define benchmark queries and corresponding gold standard case IDs\n",
+ "benchmark_queries = [\n",
+ " \"appendix inflammation in a child\",\n",
+ " \"Pyloric Stenosis in a child\",\n",
+ " \"perforation of the esophagus\",\n",
+ " \"papilloma of the breast\",\n",
+ " \"injury to the neck vessel\"\n",
+ "]\n",
+ "\n",
+ "benchmark_gold_ids = [\n",
+ " \"85A\", # Appendicitis Pediatrics\n",
+ " \"90A\", # Pyloric Stenosis Pediatrics\n",
+ " \"16A\", # Esophageal Perforation\n",
+ " \"43A\", # Intraductal Papilloma\n",
+ " \"66A\" # Vascular Neck Injury\n",
+ "]\n",
+ "\n",
+ "# Run benchmarking at k=5 using the initialized retriever\n",
+ "k_value = 5\n",
+ "results = retrieval_metrics(retriever, benchmark_queries, benchmark_gold_ids, k=k_value)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "compare-nonrag-markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Compare with Non-RAG Simulation\n",
+ "\n",
+ "This section demonstrates generating a synthetic case directly with an LLM and running the simulation using the `DummyRetriever` for comparison. \n",
+ "We will use the same example ('intussusception in a pediatric patient')."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "compare-nonrag-gen-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Please enter the topic that you would like to generate synthetic data about:**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# --- Generate and Process Synthetic Case ---\n",
+ "display(Markdown(f\"**➡️ Please enter the topic that you would like to generate synthetic data about:**\"))\n",
+ "synthetic_query = input(\"Answer: \")\n",
+ "df_synthetic_case = pd.DataFrame()\n",
+ "\n",
+ "raw_synthetic_text = generate_synthetic_case(synthetic_query, model_id=EVALUATOR_MODEL_ID)\n",
+ "df_synthetic_case = process_synthetic_data(synthetic_query, raw_synthetic_text)\n",
+ "display(df_synthetic_case)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9698b316-528f-4324-a6a5-ca27b1407b12",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize components and starting case\n",
+ "dummy_retriever = DummyRetriever(df_synthetic_case)\n",
+ "non_rag_simulator = OralExamSimulator(dummy_retriever, evaluator)\n",
+ "non_rag_case_info = non_rag_simulator.start_new_case(clinical_query=synthetic_query)\n",
+ "print(f\"[Q1] {non_rag_case_info.get('current_question', 'N/A')}\")\n",
+ "\n",
+ "display(Markdown(f\"**➡️ Please enter your answer:**\"))\n",
+ "non_rag_answer_1 = input(\"Answer :\")\n",
+ "non_rag_result_1 = non_rag_simulator.process_user_response(non_rag_answer_1)\n",
+ "print(f\"[F1] {non_rag_result_1.get('feedback', 'N/A')}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f5ec35c2-1bc0-4602-a194-52fa62acb94d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# --- Process Turn 2 ---\n",
+ "print(f\"[Q2] {non_rag_result_1.get('next_question', 'N/A')}\")\n",
+ "display(Markdown(f\"**➡️ Please enter your answer:**\"))\n",
+ "non_rag_answer_2 = input(\"Answer :\")\n",
+ "non_rag_result_2 = non_rag_simulator.process_user_response(non_rag_answer_2)\n",
+ "print(f\"[F2] {non_rag_result_2.get('feedback', 'N/A')}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "llm_course",
+ "language": "python",
+ "name": "llm_course"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..71433ac3251eb65a4780e8fe6ad48c8c723326ff
--- /dev/null
+++ b/notebooks/demo.ipynb
@@ -0,0 +1,1182 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "intro-markdown",
+ "metadata": {},
+ "source": [
+ "# RAG Pipeline Demo: Surgery Oral Board Simulation\n",
+ "\n",
+ "This notebook demonstrates the usage of the RAG (Retrieval-Augmented Generation) pipeline developed for simulating surgery oral board scenarios. It utilizes the custom Python modules located in the `src/` directory.\n",
+ "\n",
+ "**Pipeline Components:**\n",
+ "* `data_processing`: Loads and preprocesses raw case data from `.docx` files.\n",
+ "* `ClinicalCaseProcessor`: Creates embeddings for clinical cases.\n",
+ "* `ClinicalCaseRetriever`: Retrieves relevant cases based on semantic similarity to a query.\n",
+ "* `AnswerEvaluator`: Uses an LLM to evaluate user responses against expected answers.\n",
+ "* `OralExamSimulator`: Orchestrates the simulation flow."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "setup-markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Setup\n",
+ "\n",
+ "Import necessary libraries and custom modules. Define constants and handle Hugging Face Hub authentication."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "setup-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-04-22 12:59:49.329544: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
+ "E0000 00:00:1745341189.621398 614864 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "E0000 00:00:1745341189.717816 614864 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "W0000 00:00:1745341190.668036 614864 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1745341190.668086 614864 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1745341190.668088 614864 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1745341190.668090 614864 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "2025-04-22 12:59:50.675395: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import sys \n",
+ "import re\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from dotenv import load_dotenv\n",
+ "from huggingface_hub import login\n",
+ "from sklearn.metrics import ndcg_score\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "from IPython.display import display, Markdown\n",
+ "\n",
+ "# --- Add project root to sys.path ---\n",
+ "project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))\n",
+ "if project_root not in sys.path:\n",
+ " sys.path.append(project_root)\n",
+ "\n",
+ "# --- Custom Module Imports ---\n",
+ "from src.data_processing import process_all_cases, ClinicalCaseProcessor\n",
+ "from src.retriever import ClinicalCaseRetriever, DummyRetriever\n",
+ "from src.evaluator import AnswerEvaluator\n",
+ "from src.simulator import OralExamSimulator\n",
+ "from src.evaluation_utils import retrieval_metrics\n",
+ "from src.synthetic_generator import generate_synthetic_case, process_synthetic_data\n",
+ "\n",
+ "# --- Configuration ---\n",
+ "DATA_FOLDER = \"../data/\"\n",
+ "PROCESSED_DATA_PATH = \"../processed_clinical_cases\"\n",
+ "EVALUATOR_MODEL_ID = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
+ "EMBEDDING_MODEL_ID = \"all-MiniLM-L6-v2\"\n",
+ "\n",
+ "# Hugging Face Hub Authentication (using .env file in project root)\n",
+ "dotenv_path = os.path.join(project_root, 'hf_login.env')\n",
+ "load_dotenv(dotenv_path=dotenv_path)\n",
+ "hf_key = os.getenv(\"HF_KEY\")\n",
+ "\n",
+ "# --- Set Cache Directory (Optional) ---\n",
+ "cache_dir = \"/scratch/krb3ym/models/cache\" # Keep your specific path\n",
+ "os.environ['HF_HOME'] = cache_dir\n",
+ "os.makedirs(cache_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "load-data-markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Load and Preprocess Raw Data\n",
+ "\n",
+ "Load the clinical cases from the `.docx` files into a pandas DataFrame using the utility function from `src.data_processing`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "load-data-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing case files from: ../data/\n",
+ "Warning: Could not parse filename format: BTK_-_45A__Breast_Cancer.docx\n",
+ "Loaded 99 cases. Displaying head:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " case_id | \n",
+ " clinical_presentation | \n",
+ " turn_id | \n",
+ " question | \n",
+ " answer | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 85A | \n",
+ " Appendicitis Pediatrics | \n",
+ " 1 | \n",
+ " You were called to the emergency department to... | \n",
+ " Okay, well, I would begin by evaluating the pa... | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 85A | \n",
+ " Appendicitis Pediatrics | \n",
+ " 2 | \n",
+ " The patient is febrile at 101.2 degrees, heart... | \n",
+ " Okay, does he have a positive psoas or a Rovsi... | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 85A | \n",
+ " Appendicitis Pediatrics | \n",
+ " 3 | \n",
+ " Yes, both are positive. | \n",
+ " I'm concerned about acute appendicitis, but th... | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 85A | \n",
+ " Appendicitis Pediatrics | \n",
+ " 4 | \n",
+ " Labs are in order. White blood cell count of 1... | \n",
+ " Well, this is diagnostic of acute appendicitis... | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 85A | \n",
+ " Appendicitis Pediatrics | \n",
+ " 5 | \n",
+ " Are there any other therapeutic options for th... | \n",
+ " Yes, the patient has early acute appendicitis ... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " case_id clinical_presentation turn_id \\\n",
+ "0 85A Appendicitis Pediatrics 1 \n",
+ "1 85A Appendicitis Pediatrics 2 \n",
+ "2 85A Appendicitis Pediatrics 3 \n",
+ "3 85A Appendicitis Pediatrics 4 \n",
+ "4 85A Appendicitis Pediatrics 5 \n",
+ "\n",
+ " question \\\n",
+ "0 You were called to the emergency department to... \n",
+ "1 The patient is febrile at 101.2 degrees, heart... \n",
+ "2 Yes, both are positive. \n",
+ "3 Labs are in order. White blood cell count of 1... \n",
+ "4 Are there any other therapeutic options for th... \n",
+ "\n",
+ " answer \n",
+ "0 Okay, well, I would begin by evaluating the pa... \n",
+ "1 Okay, does he have a positive psoas or a Rovsi... \n",
+ "2 I'm concerned about acute appendicitis, but th... \n",
+ "3 Well, this is diagnostic of acute appendicitis... \n",
+ "4 Yes, the patient has early acute appendicitis ... "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "dataframe_btk = process_all_cases(DATA_FOLDER)\n",
+ "print(f\"Loaded {dataframe_btk['case_id'].nunique()} cases. Displaying head:\")\n",
+ "display(dataframe_btk.head())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "prepare-rag-markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Prepare RAG Dataset with Embeddings\n",
+ "\n",
+ "Use the `ClinicalCaseProcessor` to process the DataFrame, generate embeddings for each case summary, and save the result as a Hugging Face Dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "prepare-rag-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Initializing ClinicalCaseProcessor with model: all-MiniLM-L6-v2\n",
+ "Using device: cuda\n",
+ "Using provided DataFrame.\n",
+ "Raw data shape: (973, 5)\n",
+ "Grouping data by case...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing Cases: 100%|██████████| 99/99 [00:00<00:00, 6184.54it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processed data into 99 unique cases.\n",
+ "Generating embeddings for 99 case summaries...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Embedding Batches: 100%|██████████| 7/7 [00:00<00:00, 11.19it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated embeddings with shape: (99, 384)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ef6d95d464ea4babb9adeb1ff9b16582",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/1 shards): 0%| | 0/99 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processed dataset saved successfully to ../processed_clinical_cases\n",
+ "Dataset({\n",
+ " features: ['case_id', 'clinical_presentation', 'questions', 'answers', 'case_summary', 'embeddings'],\n",
+ " num_rows: 99\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "processor = ClinicalCaseProcessor(model_name=EMBEDDING_MODEL_ID)\n",
+ "processed_dataset = processor.preprocess_data(dataframe_btk, output_path=PROCESSED_DATA_PATH)\n",
+ "print(processed_dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "init-pipeline-markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Initialize RAG Pipeline Components\n",
+ "\n",
+ "Instantiate the retriever, evaluator, and simulator classes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "init-pipeline-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Initializing ClinicalCaseRetriever with model: all-MiniLM-L6-v2\n",
+ "Dataset loaded successfully from disk: ../processed_clinical_cases\n",
+ "Dataset features: {'case_id': Value(dtype='string', id=None), 'clinical_presentation': Value(dtype='string', id=None), 'questions': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'answers': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'case_summary': Value(dtype='string', id=None), 'embeddings': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}\n",
+ "Number of cases in dataset: 99\n",
+ "Using device: cuda\n",
+ "Loaded 99 cases with embeddings of shape (99, 384)\n",
+ "Initializing AnswerEvaluator with model: meta-llama/Llama-3.2-3B-Instruct\n",
+ "Set pad_token to eos_token\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "02a5e6af243344489487ed8639da9db3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "AnswerEvaluator model loaded successfully on device: cuda:0\n"
+ ]
+ }
+ ],
+ "source": [
+ "retriever = ClinicalCaseRetriever(dataset_path=PROCESSED_DATA_PATH, model_name=EMBEDDING_MODEL_ID)\n",
+ "evaluator = AnswerEvaluator(model_id=EVALUATOR_MODEL_ID)\n",
+ "simulator = OralExamSimulator(retriever, evaluator)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Run Simulation with RAG\n",
+ "\n",
+ "Demonstrate the core simulation loop: start a case based on a query, process user responses, and get feedback. \n",
+ "Let's take the example of 'bowel intussusception in a child'."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-start-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.1 Start a New Case"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "run-sim-start-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Please enter the clinical scenario that you would like to be examined on:**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " Your Answer: bowel intussusception in child\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--------------------------------------------------\n",
+ "Attempting to start new case | Query: 'bowel intussusception in child' | Index: None\n",
+ "Encoding query: 'Clinical case about bowel intussusception in child'\n",
+ "Retrieved 1 cases with similarity scores:\n",
+ "- Intussusception Pediatrics: 0.7762\n",
+ "Retrieved case via query ('bowel intussusception in child') with score 0.7762: Intussusception Pediatrics\n",
+ "Case successfully started. Total questions: 12\n",
+ "--------------------------------------------------\n",
+ "\n",
+ "Case Started: Intussusception Pediatrics\n",
+ "Total Questions: 12\n",
+ "\n",
+ "--- Question 1 --- \n",
+ "You're called to the emergency department to evaluate an 18-month-old boy with an eight-hour history of intermittent intense abdominal pain. His mother notes that the pain starts abruptly, and that the kid will draw his knees up to his chest during these painful episodes. Pain lasts for 30 minutes at a time and spontaneously resolves. He's had one episode of non-bloody emesis. He's not had diarrhea or bloody stools.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# --- Start Simulation ---\n",
+ "display(Markdown(f\"**➡️ Please enter the clinical scenario that you would like to be examined on:**\"))\n",
+ "query = input(\" Your Answer: \")\n",
+ "case_info = simulator.start_new_case(clinical_query=query)\n",
+ "\n",
+ "# Directly print the first question (assuming success)\n",
+ "print(f\"\\nCase Started: {case_info.get('clinical_presentation', 'N/A')}\")\n",
+ "print(f\"Total Questions: {case_info.get('total_questions', 'N/A')}\")\n",
+ "print(\"\\n--- Question 1 --- \")\n",
+ "print(case_info.get('current_question', 'Error: No question found'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-turn1-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.2 Process User Response (Turn 1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "44c6e822-5508-4db2-8700-79f25dc14e05",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Your Turn (Question 1/12)**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " Your Answer: Perform physical exam, looking specifically for abdominal signs, order labs, and get vitals\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "⏳ Processing User Answer...\n",
+ "--------------------------------------------------\n",
+ "Processing response for Question 1/12\n",
+ "User Response: Perform physical exam, looking specifically for abdominal signs, order labs, and get vitals\n",
+ "Expected Answer: I'd begin by obtaining vital signs and performing a comprehensive history and physical examination, focusing on his history, birth history, and my exam focusing on his abdominal exam, looking for hernias, and doing a rectal exam as well.\n",
+ "Generated Feedback: The resident’s response is: \n",
+ "Perform physical exam, looking specifically for abdominal signs, order labs, and get vitals\n",
+ "\n",
+ "ASSESSMENT: Partially Correct\n",
+ "\n",
+ "The resident mentioned performing a physical exam and ordering labs, but missed the importance of getting vitals. The resident’s response is partially correct because they included some key points but omitted others, such as the importance of obtaining vital signs.\n",
+ "Next question (1/12): The patient's tachycardic, the rest of the vital signs are normal. Your exam reveals a toddler in the fetal position on the stretcher. He appears to be in moderate distress. He has diffuse abdominal tenderness and voluntary guarding.\n",
+ "--------------------------------------------------\n",
+ "\n",
+ "❓Next Question (1/12)\n",
+ "\n",
+ "The patient's tachycardic, the rest of the vital signs are normal. Your exam reveals a toddler in the fetal position on the stretcher. He appears to be in moderate distress. He has diffuse abdominal tenderness and voluntary guarding.\n"
+ ]
+ }
+ ],
+ "source": [
+ "display(Markdown(f\"**➡️ Your Turn (Question {simulator.current_question_idx + 1}/{len(simulator.current_case['questions'])})**\"))\n",
+ "user_answer = input(\" Your Answer: \")\n",
+ "\n",
+ "# --- Process Response ---\n",
+ "print(f\"\\n⏳ Processing User Answer...\")\n",
+ "result = simulator.process_user_response(user_answer)\n",
+ "\n",
+ "# Next question\n",
+ "q_num = result.get('question_number', '?')\n",
+ "next_q_text = result.get('next_question', '*Error: No next question found*')\n",
+ "print(f\"\\n❓Next Question ({q_num}/{len(simulator.current_case['questions'])})\\n\\n{next_q_text}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-turn2-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.3 Process User Response (Turn 2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "run-sim-turn2-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Your Turn (Question 2/12)**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ " Your Answer: Ask for labs, get imaging\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "⏳ Processing User Answer...\n",
+ "--------------------------------------------------\n",
+ "Processing response for Question 2/12\n",
+ "User Response: Ask for labs, get imaging\n",
+ "Expected Answer: I'd like to know a little bit more about the history. Has he recently been ill or had any sick contacts? When did he pass stool or gas last? Is this a first such episode? Is there any family history of GI conditions or malignancy?\n",
+ "Generated Feedback: The resident’s response is: \n",
+ "I'd like to know a little bit more about the patient's history. Has he recently been ill or had any sick contacts? When did he pass stool or gas last? Is this a first episode? Is there any family history of GI conditions or malignancy? Ask for labs, get imaging.\n",
+ "\n",
+ "The expected answer is: \n",
+ "I'd like to know a little bit more about the patient's history. Has he recently been ill or had any sick contacts? When did he pass stool or gas last? Is this a first episode? Is there any family history of GI conditions or malignancy? Ask for labs, get imaging.\n",
+ "\n",
+ "The resident’s response is identical to the expected answer. The resident includes all the\n",
+ "Next question (2/12): He has had a mild upper respiratory infection for the past week, but he's been eating normally. His last bowel movement was yesterday. This is the first time he's ever had symptoms like this, and his family and personal history is otherwise unremarkable.\n",
+ "--------------------------------------------------\n",
+ "\n",
+ "❓ Next Question (2/12)\n",
+ "\n",
+ "He has had a mild upper respiratory infection for the past week, but he's been eating normally. His last bowel movement was yesterday. This is the first time he's ever had symptoms like this, and his family and personal history is otherwise unremarkable.\n"
+ ]
+ }
+ ],
+ "source": [
+ "display(Markdown(f\"**➡️ Your Turn (Question {simulator.current_question_idx + 1}/{len(simulator.current_case['questions'])})**\"))\n",
+ "user_answer = input(\" Your Answer: \")\n",
+ "\n",
+ "print(f\"\\n⏳ Processing User Answer...\")\n",
+ "result = simulator.process_user_response(user_answer)\n",
+ "\n",
+ "q_num = result.get('question_number', '?')\n",
+ "next_q_text = result.get('next_question', '*Error: No next question found*')\n",
+ "print(f\"\\n❓ Next Question ({q_num}/{len(simulator.current_case['questions'])})\\n\\n{next_q_text}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "run-sim-summary-markdown",
+ "metadata": {},
+ "source": [
+ "### 5.4 Session Summary and Saving"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "run-sim-summary-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "--- Generating Session Summary --- \n",
+ "Case: Intussusception Pediatrics\n",
+ "Case ID: 87A\n",
+ "Total Questions in Case: 12\n",
+ "Number of Interactions Logged: 8\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"\\n--- Generating Session Summary --- \")\n",
+ "session_summary = simulator.generate_session_summary()\n",
+ "\n",
+ "print(f\"Case: {session_summary.get('case', 'N/A')}\")\n",
+ "print(f\"Case ID: {session_summary.get('case_id', 'N/A')}\")\n",
+ "print(f\"Total Questions in Case: {session_summary.get('total_questions_in_case', 'N/A')}\")\n",
+ "print(f\"Number of Interactions Logged: {len(session_summary.get('interaction_history', []))}\")\n",
+ "\n",
+ "# Display full session interaction\n",
+ "#print(json.dumps(session_summary, indent=2)) # Keep commented out or remove"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "eval-retrieval-markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Evaluate Retrieval Performance\n",
+ "\n",
+ "Creating dummy dataset to evaluate RAG performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "eval-retrieval-run-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Calculating retrieval metrics for 5 queries (k=5)...\n",
+ "\n",
+ "Processing query 1/5: 'appendix inflammation in a child' (Expected ID: '85A')\n",
+ "Encoding query: 'Clinical case about appendix inflammation in a child'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Appendicitis Pediatrics: 0.7076\n",
+ "- Acute Appendicitis: 0.5978\n",
+ "- Intussusception Pediatrics: 0.5745\n",
+ "- Meckel s Diverticulum Pediatrics: 0.5277\n",
+ "- Abdominal Mass Pediatrics: 0.4705\n",
+ "Retrieved IDs: ['85A', '31A', '87A', '89A', '84A']\n",
+ "Retrieved Scores: [0.7076, 0.5978, 0.5745, 0.5277, 0.4705]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 2/5: 'Pyloric Stenosis in a child' (Expected ID: '90A')\n",
+ "Encoding query: 'Clinical case about Pyloric Stenosis in a child'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Pyloric Stenosis Pediatrics: 0.7554\n",
+ "- Carotid Stenosis: 0.4458\n",
+ "- Pheochromocytoma: 0.4139\n",
+ "- Small Bowel Obstruction: 0.3876\n",
+ "- Intussusception Pediatrics: 0.3855\n",
+ "Retrieved IDs: ['90A', '79A', '54A', '23A', '87A']\n",
+ "Retrieved Scores: [0.7554, 0.4458, 0.4139, 0.3876, 0.3855]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 3/5: 'perforation of the esophagus' (Expected ID: '16A')\n",
+ "Encoding query: 'Clinical case about perforation of the esophagus'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Esophageal Perforation: 0.6348\n",
+ "- Esophageal Cancer: 0.5415\n",
+ "- Esophageal Dysmotility: 0.5269\n",
+ "- ARDS: 0.4708\n",
+ "- Esophagus and Trachea Trauma Empyema: 0.4500\n",
+ "Retrieved IDs: ['16A', '14A', '15A', '65A', '71A']\n",
+ "Retrieved Scores: [0.6348, 0.5415, 0.5269, 0.4708, 0.45]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 4/5: 'papilloma of the breast' (Expected ID: '43A')\n",
+ "Encoding query: 'Clinical case about papilloma of the breast'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Intraductal Papilloma: 0.7125\n",
+ "- Unknown: 0.4971\n",
+ "- Inflammatory Breast Cancer: 0.4842\n",
+ "- Breast Abscess: 0.4826\n",
+ "- Breast Cancer in Pregnancy: 0.4616\n",
+ "Retrieved IDs: ['43A', 'Unknown', '42A', '39A', '40A']\n",
+ "Retrieved Scores: [0.7125, 0.4971, 0.4842, 0.4826, 0.4616]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "Processing query 5/5: 'injury to the neck vessel' (Expected ID: '66A')\n",
+ "Encoding query: 'Clinical case about injury to the neck vessel'\n",
+ "Retrieved 5 cases with similarity scores:\n",
+ "- Vascular Neck Injury: 0.6604\n",
+ "- Esophagus and Trachea Trauma Empyema: 0.4970\n",
+ "- Liver Trauma: 0.4957\n",
+ "- Carotid Stenosis: 0.4798\n",
+ "- Lower Extremity Vascular Injury: 0.4598\n",
+ "Retrieved IDs: ['66A', '71A', '72A', '79A', '68A']\n",
+ "Retrieved Scores: [0.6604, 0.497, 0.4957, 0.4798, 0.4598]\n",
+ "Hit: 1, Rank: 1, NDCG@5: 1.0000\n",
+ "\n",
+ "--- Overall Retrieval Results (k=5) --- \n",
+ "Average Hit@5: 1.0000\n",
+ "Average MRR: 1.0000\n",
+ "Average NDCG@5: 1.0000\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define benchmark queries and corresponding gold standard case IDs\n",
+ "benchmark_queries = [\n",
+ " \"appendix inflammation in a child\",\n",
+ " \"Pyloric Stenosis in a child\",\n",
+ " \"perforation of the esophagus\",\n",
+ " \"papilloma of the breast\",\n",
+ " \"injury to the neck vessel\"\n",
+ "]\n",
+ "\n",
+ "benchmark_gold_ids = [\n",
+ " \"85A\", # Appendicitis Pediatrics\n",
+ " \"90A\", # Pyloric Stenosis Pediatrics\n",
+ " \"16A\", # Esophageal Perforation\n",
+ " \"43A\", # Intraductal Papilloma\n",
+ " \"66A\" # Vascular Neck Injury\n",
+ "]\n",
+ "\n",
+ "# Run benchmarking at k=5 using the initialized retriever\n",
+ "k_value = 5\n",
+ "results = retrieval_metrics(retriever, benchmark_queries, benchmark_gold_ids, k=k_value)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "compare-nonrag-markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Compare with Non-RAG Simulation\n",
+ "\n",
+ "This section demonstrates generating a synthetic case directly with an LLM and running the simulation using the `DummyRetriever` for comparison. \n",
+ "We will use the same example (intussusception in pediatric patient')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "compare-nonrag-gen-code",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "**➡️ Please enter the topic that you would like to generate synthetic data about:**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdin",
+ "output_type": "stream",
+ "text": [
+ "Answer: bowel intussusception child\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generating synthetic case for 'bowel intussusception child' using meta-llama/Llama-3.2-3B-Instruct...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c7501d4d2b714cb7b5c8bc7199138a85",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Synthetic case generation complete.\n",
+ "Processed synthetic data into DataFrame with 8 turns.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " case_id | \n",
+ " clinical_presentation | \n",
+ " turn_id | \n",
+ " question | \n",
+ " answer | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 1 | \n",
+ " **\\nA 4-year-old boy presents to the emergency... | \n",
+ " Intussusception\\n\\n** | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 2 | \n",
+ " **\\nWhat is the typical age range for intussus... | \n",
+ " 6-36 months\\n\\n** | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 3 | \n",
+ " **\\nWhat is the primary mechanism of intussusc... | \n",
+ " Invagination of a lead point, often a tumor or... | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 4 | \n",
+ " **\\nWhich of the following is a common finding... | \n",
+ " A palpable abdominal mass\\n\\n** | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 5 | \n",
+ " **\\nWhat is the most common location for intus... | \n",
+ " Ileum\\n\\n** | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 6 | \n",
+ " **\\nWhat is the primary treatment for intussus... | \n",
+ " Reduction with air enema\\n\\n** | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 7 | \n",
+ " **\\nWhat is a potential complication of intuss... | \n",
+ " Perforation and peritonitis\\n\\n** | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " SYNTH_01 | \n",
+ " bowel intussusception child | \n",
+ " 8 | \n",
+ " **\\nHow often does intussusception recur after... | \n",
+ " Less than 5% of cases, with most recurrences o... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " case_id clinical_presentation turn_id \\\n",
+ "0 SYNTH_01 bowel intussusception child 1 \n",
+ "1 SYNTH_01 bowel intussusception child 2 \n",
+ "2 SYNTH_01 bowel intussusception child 3 \n",
+ "3 SYNTH_01 bowel intussusception child 4 \n",
+ "4 SYNTH_01 bowel intussusception child 5 \n",
+ "5 SYNTH_01 bowel intussusception child 6 \n",
+ "6 SYNTH_01 bowel intussusception child 7 \n",
+ "7 SYNTH_01 bowel intussusception child 8 \n",
+ "\n",
+ " question \\\n",
+ "0 **\\nA 4-year-old boy presents to the emergency... \n",
+ "1 **\\nWhat is the typical age range for intussus... \n",
+ "2 **\\nWhat is the primary mechanism of intussusc... \n",
+ "3 **\\nWhich of the following is a common finding... \n",
+ "4 **\\nWhat is the most common location for intus... \n",
+ "5 **\\nWhat is the primary treatment for intussus... \n",
+ "6 **\\nWhat is a potential complication of intuss... \n",
+ "7 **\\nHow often does intussusception recur after... \n",
+ "\n",
+ " answer \n",
+ "0 Intussusception\\n\\n** \n",
+ "1 6-36 months\\n\\n** \n",
+ "2 Invagination of a lead point, often a tumor or... \n",
+ "3 A palpable abdominal mass\\n\\n** \n",
+ "4 Ileum\\n\\n** \n",
+ "5 Reduction with air enema\\n\\n** \n",
+ "6 Perforation and peritonitis\\n\\n** \n",
+ "7 Less than 5% of cases, with most recurrences o... "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# --- Generate and Process Synthetic Case ---\n",
+ "display(Markdown(f\"**➡️ Please enter the topic that you would like to generate synthetic data about:**\"))\n",
+ "synthetic_query = input(\"Answer: \")\n",
+ "df_synthetic_case = pd.DataFrame()\n",
+ "\n",
+ "raw_synthetic_text = generate_synthetic_case(synthetic_query, model_id=EVALUATOR_MODEL_ID)\n",
+ "df_synthetic_case = process_synthetic_data(synthetic_query, raw_synthetic_text)\n",
+ "display(df_synthetic_case)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9698b316-528f-4324-a6a5-ca27b1407b12",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DummyRetriever processing 1 unique presentations.\n",
+ "DummyRetriever initialized with 1 cases.\n",
+ "--------------------------------------------------\n",
+ "Attempting to start new case | Query: 'bowel intussusception child' | Index: None\n",
+ "DummyRetriever searching for exact match: 'bowel intussusception child'\n",
+ "DummyRetriever found match: bowel intussusception child\n",
+ "Retrieved case via query ('bowel intussusception child') with score 1.0000: bowel intussusception child\n",
+ "Case successfully started. Total questions: 8\n",
+ "--------------------------------------------------\n",
+ "[Q1] **\n",
+ "A 4-year-old boy presents to the emergency department with abdominal pain and vomiting. The child was playing outside and suddenly became ill, unable to continue playing. He has no history of abdominal pain or gastrointestinal symptoms.\n",
+ "\n",
+ "**Q1:**\n",
+ "What is the likely underlying cause of the child's symptoms?\n",
+ "A1: Intussusception\n",
+ "\n",
+ "**Q2:**\n",
+ "What is the typical age range for intussusception in children?\n",
+ "A2: 6-36 months\n",
+ "\n",
+ "**Q3:**\n",
+ "What is the primary mechanism of intussusception in children?\n",
+ "A3: Invagination of a lead point, often a tumor or inflammatory lesion, into the intestine\n",
+ "\n",
+ "**Q4:**\n",
+ "Which of the following is a common finding in children with intussusception?\n",
+ "A4: A palpable abdominal mass\n",
+ "\n",
+ "**Q5:**\n",
+ "What is the most common location for intussusception in children?\n",
+ "A5: Ileum\n",
+ "\n",
+ "**Q6:**\n",
+ "What is the primary treatment for intussusception in children?\n",
+ "A6: Reduction with air enema\n",
+ "\n",
+ "**Q7:**\n",
+ "What is a potential complication of intussusception if not treated promptly?\n",
+ "A7: Perforation and peritonitis\n",
+ "\n",
+ "**Q8:**\n",
+ "How often does intussusception recur after successful treatment?\n",
+ "A8: Less than 5% of cases, with most recurrences occurring within 6 months of initial presentation. **\n",
+ "What is the likely underlying cause of the child's symptoms?\n"
+ ]
+ },
+ {
+ "data": {
+ "text/markdown": [
+       "**➡️ Please enter your answer:**"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Initialize components and starting case\n",
+ "dummy_retriever = DummyRetriever(df_synthetic_case)\n",
+ "non_rag_simulator = OralExamSimulator(dummy_retriever, evaluator)\n",
+ "non_rag_case_info = non_rag_simulator.start_new_case(clinical_query=synthetic_query)\n",
+ "print(f\"[Q1] {non_rag_case_info.get('current_question', 'N/A')}\")\n",
+ "\n",
+    "display(Markdown(f\"**➡️ Please enter your answer:**\"))\n",
+ "non_rag_answer_1 = input(\"Answer :\")\n",
+ "non_rag_result_1 = non_rag_simulator.process_user_response(non_rag_answer_1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f5ec35c2-1bc0-4602-a194-52fa62acb94d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# --- Process Turn 2 ---\n",
+ "print(f\"[Q2] {non_rag_result_1.get('next_question', 'N/A')}\")\n",
+    "display(Markdown(f\"**➡️ Please enter your answer:**\"))\n",
+ "non_rag_answer_2 = input(\"Answer :\")\n",
+ "non_rag_answer_2 = non_rag_simulator.process_user_response(non_rag_answer_2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "056083b9-8181-44f6-acdc-65aae68d3243",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.chdir(\"../\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "acac5513-2a0a-4adf-8cff-f9818e024153",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'/sfs/gpfs/tardis/home/krb3ym/Documents/MSDS/DS5002 - LLMs/Final Project/huggingface-repo'"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pwd()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "588dc01a-2da6-4c59-8726-ea798279d6c8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import HfApi\n",
+ "api = HfApi()\n",
+ "\n",
+ "api.upload_folder(\n",
+ " folder_path=\"/path/to/local/space\",\n",
+ " repo_id=\"username/my-cool-space\",\n",
+ " repo_type=\"space\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "036f6fd9-5826-4af3-b5a8-b255ecc978c8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/krb3ym/.local/lib/python3.11/site-packages/huggingface_hub/hf_api.py:3987: UserWarning: It seems that you are about to commit a data file (processed_clinical_cases/data-00000-of-00001.arrow) to a space repository. You are sure this is intended? If you are trying to upload a dataset, please set `repo_type='dataset'` or `--repo-type=dataset` in a CLI.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "ename": "RepositoryNotFoundError",
+ "evalue": "404 Client Error. (Request ID: Root=1-6807fa93-154f77b15a7a81585b63e8b4;bc6c900e-cf73-4777-be1e-f96c6c056daa)\n\nRepository Not Found for url: https://huggingface.co/api/spaces/melmoheb/boardgpt-llm/preupload/main.\nPlease make sure you specified the correct `repo_id` and `repo_type`.\nIf you are trying to access a private or gated repo, make sure you are authenticated.\nNote: Creating a commit assumes that the repo already exists on the Huggingface Hub. Please use `create_repo` if it's not the case.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/utils/_http.py:406\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m 405\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 406\u001b[0m \u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/requests/models.py:1024\u001b[0m, in \u001b[0;36mResponse.raise_for_status\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m http_error_msg:\n\u001b[0;32m-> 1024\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(http_error_msg, response\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m)\n",
+ "\u001b[0;31mHTTPError\u001b[0m: 404 Client Error: Not Found for url: https://huggingface.co/api/spaces/melmoheb/boardgpt-llm/preupload/main",
+ "\nThe above exception was the direct cause of the following exception:\n",
+ "\u001b[0;31mRepositoryNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[14], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m \n\u001b[1;32m 4\u001b[0m api \u001b[38;5;241m=\u001b[39m HfApi(token \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mHF_TOKEN\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[0;32m----> 5\u001b[0m \u001b[43mapi\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupload_folder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mfolder_path\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmelmoheb/boardgpt-llm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mspace\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/hf_api.py:1559\u001b[0m, in \u001b[0;36mfuture_compatible.._inner\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1556\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_as_future(fn, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# Otherwise, call the function normally\u001b[39;00m\n\u001b[0;32m-> 1559\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/hf_api.py:5105\u001b[0m, in \u001b[0;36mHfApi.upload_folder\u001b[0;34m(self, repo_id, folder_path, path_in_repo, commit_message, commit_description, token, repo_type, revision, create_pr, parent_commit, allow_patterns, ignore_patterns, delete_patterns, multi_commits, multi_commits_verbose, run_as_future)\u001b[0m\n\u001b[1;32m 5101\u001b[0m \u001b[38;5;66;03m# Defining a CommitInfo object is not really relevant in this case\u001b[39;00m\n\u001b[1;32m 5102\u001b[0m \u001b[38;5;66;03m# Let's return early with pr_url only (as string).\u001b[39;00m\n\u001b[1;32m 5103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pr_url\n\u001b[0;32m-> 5105\u001b[0m commit_info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_commit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 5106\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5107\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5108\u001b[0m \u001b[43m \u001b[49m\u001b[43moperations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_operations\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5109\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_message\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_message\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5110\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_description\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_description\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5111\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5112\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5113\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_pr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5114\u001b[0m \u001b[43m \u001b[49m\u001b[43mparent_commit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparent_commit\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5115\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5117\u001b[0m \u001b[38;5;66;03m# Create url to uploaded folder (for legacy return value)\u001b[39;00m\n\u001b[1;32m 5118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m create_pr \u001b[38;5;129;01mand\u001b[39;00m commit_info\u001b[38;5;241m.\u001b[39mpr_url \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/hf_api.py:1559\u001b[0m, in \u001b[0;36mfuture_compatible.._inner\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1556\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_as_future(fn, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# Otherwise, call the function normally\u001b[39;00m\n\u001b[0;32m-> 1559\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/hf_api.py:4011\u001b[0m, in \u001b[0;36mHfApi.create_commit\u001b[0;34m(self, repo_id, operations, commit_message, commit_description, token, repo_type, revision, create_pr, num_threads, parent_commit, run_as_future)\u001b[0m\n\u001b[1;32m 4008\u001b[0m \u001b[38;5;66;03m# If updating twice the same file or update then delete a file in a single commit\u001b[39;00m\n\u001b[1;32m 4009\u001b[0m _warn_on_overwriting_operations(operations)\n\u001b[0;32m-> 4011\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpreupload_lfs_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4012\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4013\u001b[0m \u001b[43m \u001b[49m\u001b[43madditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madditions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4014\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4015\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4016\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43munquoted_revision\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# first-class methods take unquoted revision\u001b[39;49;00m\n\u001b[1;32m 4017\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_pr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4018\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_threads\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_threads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4019\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mfree_memory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for \"normal\" users\u001b[39;49;00m\n\u001b[1;32m 4020\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4022\u001b[0m \u001b[38;5;66;03m# Remove no-op operations (files that have not changed)\u001b[39;00m\n\u001b[1;32m 4023\u001b[0m operations_without_no_op \u001b[38;5;241m=\u001b[39m []\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/hf_api.py:4530\u001b[0m, in \u001b[0;36mHfApi.preupload_lfs_files\u001b[0;34m(self, repo_id, additions, token, repo_type, revision, create_pr, num_threads, free_memory, gitignore_content)\u001b[0m\n\u001b[1;32m 4528\u001b[0m \u001b[38;5;66;03m# Check which new files are LFS\u001b[39;00m\n\u001b[1;32m 4529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 4530\u001b[0m \u001b[43m_fetch_upload_modes\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4531\u001b[0m \u001b[43m \u001b[49m\u001b[43madditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnew_additions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4532\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4533\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4534\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4535\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4536\u001b[0m \u001b[43m \u001b[49m\u001b[43mendpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4537\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_pr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 4538\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mgitignore_content\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgitignore_content\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4539\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4540\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m RepositoryNotFoundError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 4541\u001b[0m e\u001b[38;5;241m.\u001b[39mappend_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE)\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/_commit_api.py:536\u001b[0m, in \u001b[0;36m_fetch_upload_modes\u001b[0;34m(additions, repo_type, repo_id, headers, revision, endpoint, create_pr, gitignore_content)\u001b[0m\n\u001b[1;32m 528\u001b[0m payload[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgitIgnore\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m gitignore_content\n\u001b[1;32m 530\u001b[0m resp \u001b[38;5;241m=\u001b[39m get_session()\u001b[38;5;241m.\u001b[39mpost(\n\u001b[1;32m 531\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mendpoint\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/api/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124ms/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/preupload/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrevision\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 532\u001b[0m json\u001b[38;5;241m=\u001b[39mpayload,\n\u001b[1;32m 533\u001b[0m headers\u001b[38;5;241m=\u001b[39mheaders,\n\u001b[1;32m 534\u001b[0m params\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcreate_pr\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m1\u001b[39m\u001b[38;5;124m\"\u001b[39m} \u001b[38;5;28;01mif\u001b[39;00m create_pr \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 535\u001b[0m )\n\u001b[0;32m--> 536\u001b[0m \u001b[43mhf_raise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 537\u001b[0m preupload_info \u001b[38;5;241m=\u001b[39m _validate_preupload_info(resp\u001b[38;5;241m.\u001b[39mjson())\n\u001b[1;32m 538\u001b[0m 
upload_modes\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m{file[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpath\u001b[39m\u001b[38;5;124m\"\u001b[39m]: file[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muploadMode\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m preupload_info[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfiles\u001b[39m\u001b[38;5;124m\"\u001b[39m]})\n",
+ "File \u001b[0;32m~/.local/lib/python3.11/site-packages/huggingface_hub/utils/_http.py:454\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m error_code \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRepoNotFound\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 436\u001b[0m response\u001b[38;5;241m.\u001b[39mstatus_code \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m401\u001b[39m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mrequest \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 444\u001b[0m \u001b[38;5;66;03m# => for now, we process them as `RepoNotFound` anyway.\u001b[39;00m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;66;03m# See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9\u001b[39;00m\n\u001b[1;32m 446\u001b[0m message \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 447\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;241m.\u001b[39mstatus_code\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m Client Error.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m make sure you are authenticated.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 453\u001b[0m )\n\u001b[0;32m--> 454\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m _format(RepositoryNotFoundError, message, response) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 456\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m 
response\u001b[38;5;241m.\u001b[39mstatus_code \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m400\u001b[39m:\n\u001b[1;32m 457\u001b[0m message \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 458\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mBad request for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mendpoint_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m endpoint:\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m endpoint_name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mBad request:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 459\u001b[0m )\n",
+ "\u001b[0;31mRepositoryNotFoundError\u001b[0m: 404 Client Error. (Request ID: Root=1-6807fa93-154f77b15a7a81585b63e8b4;bc6c900e-cf73-4777-be1e-f96c6c056daa)\n\nRepository Not Found for url: https://huggingface.co/api/spaces/melmoheb/boardgpt-llm/preupload/main.\nPlease make sure you specified the correct `repo_id` and `repo_type`.\nIf you are trying to access a private or gated repo, make sure you are authenticated.\nNote: Creating a commit assumes that the repo already exists on the Huggingface Hub. Please use `create_repo` if it's not the case."
+ ]
+ }
+ ],
+ "source": [
+ "from huggingface_hub import HfApi\n",
+ "import os \n",
+ "\n",
+ "api = HfApi(token = os.getenv('HF_TOKEN'))\n",
+ "api.upload_folder(\n",
+ " folder_path = \".\",\n",
+ " repo_id = \"melmoheb/boardgpt-llm\", \n",
+ " repo_type = \"space\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7d7c7fd9-3edd-4535-99c6-031946a6b43c",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "llm_course",
+ "language": "python",
+ "name": "llm_course"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/processed_clinical_cases/data-00000-of-00001.arrow b/processed_clinical_cases/data-00000-of-00001.arrow
new file mode 100644
index 0000000000000000000000000000000000000000..2379b78af586b20d4d12e80944b08939b9d62b85
--- /dev/null
+++ b/processed_clinical_cases/data-00000-of-00001.arrow
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e565f3516eebd2348bb93e430a2c4d89794670984f4decccd8cc01dec3bcc459
+size 892112
diff --git a/processed_clinical_cases/dataset_info.json b/processed_clinical_cases/dataset_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..2586d5301367f8d9fe8bf471512067dc8fe974fd
--- /dev/null
+++ b/processed_clinical_cases/dataset_info.json
@@ -0,0 +1,41 @@
+{
+ "citation": "",
+ "description": "",
+ "features": {
+ "case_id": {
+ "dtype": "string",
+ "_type": "Value"
+ },
+ "clinical_presentation": {
+ "dtype": "string",
+ "_type": "Value"
+ },
+ "questions": {
+ "feature": {
+ "dtype": "string",
+ "_type": "Value"
+ },
+ "_type": "Sequence"
+ },
+ "answers": {
+ "feature": {
+ "dtype": "string",
+ "_type": "Value"
+ },
+ "_type": "Sequence"
+ },
+ "case_summary": {
+ "dtype": "string",
+ "_type": "Value"
+ },
+ "embeddings": {
+ "feature": {
+ "dtype": "float64",
+ "_type": "Value"
+ },
+ "_type": "Sequence"
+ }
+ },
+ "homepage": "",
+ "license": ""
+}
\ No newline at end of file
diff --git a/processed_clinical_cases/state.json b/processed_clinical_cases/state.json
new file mode 100644
index 0000000000000000000000000000000000000000..f8b7749e338cd77c25a8a73ae0a3b5b6a68d0e9b
--- /dev/null
+++ b/processed_clinical_cases/state.json
@@ -0,0 +1,13 @@
+{
+ "_data_files": [
+ {
+ "filename": "data-00000-of-00001.arrow"
+ }
+ ],
+ "_fingerprint": "f41fcf52c584740a",
+ "_format_columns": null,
+ "_format_kwargs": {},
+ "_format_type": null,
+ "_output_all_columns": false,
+ "_split": null
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9ae085aae9b0460b1a1ec4a12a6c6e17f0736c66
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+torch
+transformers
+pandas
+numpy
+scikit-learn
+datasets
+sentence-transformers
+python-docx
+python-dotenv
+huggingface_hub
+tqdm
\ No newline at end of file
diff --git a/src/.ipynb_checkpoints/data_processing-checkpoint.py b/src/.ipynb_checkpoints/data_processing-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..136d75448f617a50e50ff9de2bc72a915e5f6634
--- /dev/null
+++ b/src/.ipynb_checkpoints/data_processing-checkpoint.py
@@ -0,0 +1,201 @@
+import os
+import re
+import pandas as pd
+import numpy as np
+import torch
+from docx import Document
+from sentence_transformers import SentenceTransformer
+from datasets import Dataset
+from tqdm import tqdm
+
def read_docx(file_path):
    """Return all paragraph text of a .docx file joined with newlines ('' on failure)."""
    try:
        document = Document(file_path)
        return '\n'.join(paragraph.text for paragraph in document.paragraphs)
    except Exception as exc:
        # Unreadable/corrupt documents are logged and skipped by the caller.
        print(f"Error reading {file_path}: {exc}")
        return ""
+
def extract_qa_pairs(text):
    """Extract ordered Examiner/Examinee Q&A pairs from a transcript string.

    Returns a list of {"question": ..., "answer": ...} dicts, whitespace-stripped.
    """
    qa_regex = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    pairs = []
    for question, answer in qa_regex.findall(text):
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs
+
def parse_filename(filename):
    """Parse case ID and topic from a BTK-style filename.

    Example: ``BTK_-_77A___Burn.docx`` -> ``("77A", "Burn")``.

    Args:
        filename: Base name of the case file (extension is stripped).

    Returns:
        Tuple ``(case_id, clinical_presentation)``; both ``"Unknown"`` when
        the filename does not match the expected pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Bug fix: previously printed the literal placeholder "(unknown)"
        # instead of the offending filename, making the warning useless.
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic
+
def process_all_cases(folder_path):
    """Read every .docx case file in a folder into one flat DataFrame.

    Each output row is a single Q&A turn with columns: case_id,
    clinical_presentation, turn_id, question, answer.

    Args:
        folder_path: Directory containing the BTK .docx case files.

    Returns:
        pandas.DataFrame (empty when the folder is missing or nothing parsed).
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)

    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        # Skip non-docx files and Word's temporary lock files ("~$...").
        if filename.lower().endswith('.docx') and not filename.startswith('~'):
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    # Bug fix: log the actual filename instead of the literal
                    # placeholder "(unknown)".
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,  # 1-based turn index within the case
                        "question": pair["question"],
                        "answer": pair["answer"]
                    })
            else:
                # Bug fix: log the actual filename instead of "(unknown)".
                print(f"Warning: Empty content for file {filename}")

    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")

    return pd.DataFrame(rows)
+
+
+# --- ClinicalCaseProcessor Class ---
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers.

    Groups flat per-turn Q&A rows into one record per case, builds a short
    searchable summary per case, embeds the summaries with a
    SentenceTransformer model, and saves the result as a Hugging Face
    dataset on disk.
    """
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # model_name: any sentence-transformers checkpoint name/path.
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
                Must contain the columns case_id, clinical_presentation,
                turn_id, question, answer.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings, or
            None on any validation/IO/embedding error (errors are printed,
            not raised).
        """
        # Load data: accept an in-memory DataFrame or a CSV path.
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None

        print(f"Raw data shape: {df.shape}")

        # Validate necessary columns before grouping.
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        # Group by case_id to get all Q&A pairs for each case.
        # dropna=False keeps groups whose key is NaN (handled below).
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        # Create a new dataframe with one row per case.
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct order
            group = group.sort_values('turn_id')

            # Extract questions and answers in order
            questions = group['question'].tolist()
            answers = group['answer'].tolist()

            # Handle potential NaN/None in presentation if groupby didn't drop them
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"

            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question lists).
        # Only the presentation + first question are embedded, not the full case.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        # Generate embeddings using sentence-transformers, in batches to
        # bound memory use.
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []

        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                # Generate embeddings for the batch
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)

            # Combine all batch embeddings into one (n_cases, dim) array.
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")

        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None


        # Convert to HF Dataset and add embeddings
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # Save processed dataset so ClinicalCaseRetriever can load_from_disk it.
        try:
            os.makedirs(output_path, exist_ok=True) # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None # Return None if saving failed

        return dataset
\ No newline at end of file
diff --git a/src/.ipynb_checkpoints/evaluation_utils-checkpoint.py b/src/.ipynb_checkpoints/evaluation_utils-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..1463078095037aa4008b72fd68a57951ff57f3cf
--- /dev/null
+++ b/src/.ipynb_checkpoints/evaluation_utils-checkpoint.py
@@ -0,0 +1,67 @@
+import numpy as np
+from sklearn.metrics import ndcg_score
+from src.retriever import ClinicalCaseRetriever, DummyRetriever
+
def retrieval_metrics(retriever_instance: "ClinicalCaseRetriever", queries: list[str], gold_ids: list[str], k: int = 5) -> dict | None:
    """
    Calculates retrieval metrics for a set of queries.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance
            (any object exposing ``retrieve_relevant_case(query, top_k, return_scores)``).
        queries: A list of query strings.
        gold_ids: A list of the expected 'case_id' strings for each query.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing Hit@k, MRR, and NDCG@k scores.
    """
    # --- Initialization ---
    hits, reciprocal_ranks, ndcgs = [], [], []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    # --- Process Each Query ---
    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx+1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(q, top_k=k, return_scores=True)

        # Safely extract IDs, handle missing keys
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # --- Calculate Metrics ---
        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)

        rank = 0
        if is_hit:
            rank = retrieved_ids.index(gold) + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)

        # NDCG calculation
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])

        current_ndcg = 0.0
        if true_relevance.shape[1] > 1:
            ndcg_k = min(k, true_relevance.shape[1])  # Ensure k is not out of bounds
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        elif true_relevance.shape[1] == 1:
            # Bug fix: sklearn's ndcg_score raises a ValueError when only a
            # single document is retrieved; NDCG of one result is simply its
            # binary relevance.
            current_ndcg = float(true_relevance[0, 0])
        ndcgs.append(current_ndcg)

        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    # --- Aggregate Results ---
    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0

    print(f"\n--- Overall Retrieval Results (k={k}) --- ")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR: {avg_mrr:.4f}") # Corrected spacing for alignment
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

    return {f"Hit@{k}": avg_hit,
            f"MRR": avg_mrr,
            f"NDCG@{k}": avg_ndcg}
\ No newline at end of file
diff --git a/src/.ipynb_checkpoints/evaluator-checkpoint.py b/src/.ipynb_checkpoints/evaluator-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6357843b5535e420773f1d0889d6f81406253f4
--- /dev/null
+++ b/src/.ipynb_checkpoints/evaluator-checkpoint.py
@@ -0,0 +1,113 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
class AnswerEvaluator:
    """Evaluates user answers against expected answers using an LLM.

    Loads a causal-LM once at construction time and grades free-text
    resident answers against the dataset's model answer via a rubric prompt.
    """

    def __init__(self, model_id="meta-llama/Llama-3.2-3B-Instruct"):
        # model_id: any Hugging Face causal-LM checkpoint; loaded in fp16
        # with device_map="auto" (so the device is chosen by accelerate).
        print(f"Initializing AnswerEvaluator with model: {model_id}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            if self.tokenizer.pad_token is None:
                # Some chat models ship without a pad token; reuse EOS.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("Set pad_token to eos_token")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            self.model.eval()
            self.device = self.model.device
            print(f"AnswerEvaluator model loaded successfully on device: {self.device}")

        except Exception as e:
            # Re-raise: an evaluator without a model is unusable.
            print(f"Error initializing AnswerEvaluator model {model_id}: {e}")
            raise


    def evaluate_answer(self, user_answer, expected_answer, clinical_context=None):
        """
        Compare user answer to expected answer and provide feedback

        Args:
            user_answer: Examinee's response
            expected_answer: Model answer from the dataset
            clinical_context: Optional clinical context to consider

        Returns:
            Feedback string starting with "ASSESSMENT: ..." (format is
            requested in the prompt, not enforced), or an error message
            string if generation fails.
        """
        context_str = f"Clinical context: {clinical_context}\n\n" if clinical_context else ""

        # Few-shot rubric prompt; examples anchor the expected output format.
        prompt = f"""[INST] You are acting as an expert examiner for the American Board of Surgery (ABS) oral board exam. You are evaluating a general surgery resident’s answer to a clinical question. \n
        Compare the answer provided by the residents to the correct expected answer, which I will provide you with. \n
        Use the grading rubric below to assess their response:

        [RUBRIC]
        - Correct: Resident includes all major points and clinical reasoning aligns closely with the expected answer.
        - Partially Correct: Resident includes some key points but omits others, or reasoning is partially flawed.
        - Incorrect: Resident misses most key points or demonstrates incorrect reasoning.

        {context_str}Here is the model answer that contains the key points expected from the resident:
        {expected_answer}

        Now, here is the resident’s actual response:
        {user_answer}

        Evaluate the resident’s response based **only** on the expected answer above. Do not rely on external knowledge or previous responses.

        Focus your evaluation on:
        1. Which key points were mentioned vs. missed
        2. The accuracy and clarity of the clinical reasoning
        3. Any major omissions or misunderstandings

        Start your output with:
        ASSESSMENT: [Correct / Partially Correct / Incorrect]
        Then write 1–2 clear, specific sentences explaining how the resident’s response compares to the expected answer.

        [EXAMPLE 1]
        Expected answer:
        "The differential diagnosis includes acute appendicitis, mesenteric adenitis, gastroenteritis, UTI, and testicular torsion."

        Resident’s response:
        "My top concern is appendicitis, but I’d also consider things like gastroenteritis or maybe even kidney stones."

        ASSESSMENT: Partially Correct
        The resident mentioned appendicitis and gastroenteritis but missed several other expected differentials like UTI, testicular torsion, and mesenteric adenitis.

        [EXAMPLE 2]
        Expected answer:
        "Initial labs should include CBC, CMP, lipase, and abdominal ultrasound to assess for gallstones."

        Resident’s response:
        "I’d start with a full workup including CBC, liver enzymes, lipase, and an abdominal ultrasound."

        ASSESSMENT: Correct
        The resident included all key labs and the correct imaging modality. Their reasoning aligns well with the expected answer.

        [/INST]"""

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device) # Added truncation

            with torch.no_grad():
                # Generate feedback using the model
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature = 0.2,  # low temperature for near-deterministic grading
                    pad_token_id=self.tokenizer.eos_token_id # Ensure pad token ID is set
                )

            # Strip the echoed prompt: keep only the newly generated tokens.
            prompt_length_tokens = inputs.input_ids.shape[1]
            generated_ids = outputs[0][prompt_length_tokens:]

            feedback = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            return feedback

        except Exception as e:
            print(f"Error during LLM evaluation: {e}")
            return "Error: Could not generate feedback."
\ No newline at end of file
diff --git a/src/.ipynb_checkpoints/retriever-checkpoint.py b/src/.ipynb_checkpoints/retriever-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..a740910e3b534ec99d0be1fa3ceaa917774d8302
--- /dev/null
+++ b/src/.ipynb_checkpoints/retriever-checkpoint.py
@@ -0,0 +1,172 @@
+import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+from datasets import load_from_disk, Dataset
+import os
+import pandas as pd
+
class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings.

    Loads a Hugging Face dataset (with a precomputed 'embeddings' column,
    as produced by ClinicalCaseProcessor) and ranks cases against a query
    by cosine similarity.
    """

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        # dataset_path: either an in-memory datasets.Dataset or a directory
        # previously written with Dataset.save_to_disk().
        # NOTE(review): model_name should match the model used to build the
        # stored embeddings, otherwise similarities are meaningless — confirm.
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")


        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Ensure embeddings are loaded as a NumPy array of shape (n_cases, dim).
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e


    def get_available_cases(self, n=5):
        """Return a sample of available cases for user selection.

        Returns a list of (index, clinical_presentation) tuples; at most n,
        sampled without replacement.
        """
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Ensure indices are int for slicing dataset
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Generate embedding for a query string (shape (1, dim)), or None on error."""
        # Create a better search query structure — mirrors the
        # "Clinical case: ..." phrasing used when the summaries were embedded.
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")

        # Generate embedding using sentence-transformers
        try:
            query_embedding = self.model.encode([search_query], convert_to_numpy=True, device=self.device, show_progress_bar=False)
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None # Or raise an error

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) given a query.

        NOTE(review): the return shape differs by flag — with
        return_scores=True this returns two parallel lists
        (cases, scores); with False it returns a list of
        (case_dict, score) tuples. Callers must handle both shapes.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Get query embedding
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return [] if not return_scores else ([], [])

        # Calculate similarity scores
        try:
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0] # Get the single row of similarities
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return [] if not return_scores else ([], [])

        # Get indices of top-k most similar cases
        # Ensure we don't request more indices than available cases
        k_actual = min(top_k, len(similarities))
        if k_actual == 0: # Should not happen if dataset loaded, but safe check
            return [] if not return_scores else ([], [])

        # Use partitioning for efficiency if k is much smaller than N, or argsort otherwise
        # Using argsort is generally simpler and fine for moderate N
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int) # Get top k indices, sorted descending

        top_scores = similarities[top_indices].tolist() # Get scores for these indices

        # Return the most relevant case(s)
        try:
            # Retrieve cases safely using integer indices
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return [] if not return_scores else ([], [])
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return [] if not return_scores else ([], [])


        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access presentation, provide default if missing
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        else:
            # Return list of tuples (case_dict, score)
            return results_with_scores
+
+
class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame.

    Mimics ClinicalCaseRetriever's interface: cases live as dicts in
    ``self.dataset`` and are looked up by exact 'clinical_presentation' match.
    """

    def __init__(self, df):
        """Build the in-memory case list from a flat per-turn Q&A DataFrame.

        Args:
            df: DataFrame with one row per Q&A turn; required columns are
                'clinical_presentation', 'turn_id', 'question', 'answer'
                (an optional 'case_id' column is used when present).
        """
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        # Expects df to be pre-processed with columns:
        # 'clinical_presentation', 'turn_id', 'question', 'answer'
        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for i, (scenario, group) in enumerate(grouped):
            # Sort turns so questions/answers keep exam order.
            group_sorted = group.sort_values('turn_id')

            case_dict = {
                "case_id": group_sorted['case_id'].iloc[0] if 'case_id' in group_sorted.columns else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": group_sorted["question"].tolist(),
                "answers": group_sorted["answer"].tolist()
            }
            self.dataset.append(case_dict)
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1, return_scores=False):
        """
        Finds the case matching the query string exactly.
        Ignores 'top_k' but mimics the return structure [(case_dict, score)].

        Args:
            scenario_query: Exact 'clinical_presentation' string to look up.
            top_k: Ignored; kept for interface compatibility.
            return_scores: Consistency fix — when True, return two parallel
                lists (cases, scores) like ClinicalCaseRetriever does, so
                callers such as retrieval_metrics(..., return_scores=True)
                work with either retriever. Defaults to False (old behavior).
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                if return_scores:
                    return [case_dict], [1.0]
                return [(case_dict, 1.0)]

        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        if return_scores:
            return [], []
        return []
\ No newline at end of file
diff --git a/src/.ipynb_checkpoints/simulator-checkpoint.py b/src/.ipynb_checkpoints/simulator-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f79afe531f267ae7d5d9778534b56bf792825add
--- /dev/null
+++ b/src/.ipynb_checkpoints/simulator-checkpoint.py
@@ -0,0 +1,247 @@
import json
import os

import pandas as pd

# Assuming retriever and evaluator classes are in these files:
from .retriever import ClinicalCaseRetriever, DummyRetriever
from .evaluator import AnswerEvaluator
+
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    Drives one case at a time: selects a case via the retriever, walks the
    examinee through its questions, grades each answer with the evaluator,
    and keeps a full interaction transcript in ``session_history``.
    """

    def __init__(self, retriever, evaluator):
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")

        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None          # dict for the active case, or None
        self.current_question_idx = 0     # index of the question awaiting an answer
        self.session_history = []         # ordered transcript of role/content dicts

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the retriever's dataset.
                Takes precedence over clinical_query when both are given.

        Returns:
            dict: Contains case info and first question, or an error message
            under the "error" key.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset state for the new case
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        # Case selection logic
        retrieved_info = None # Use a temporary variable
        if case_idx is not None:
            try:
                # Direct case selection by index
                # Ensure index is valid
                if 0 <= int(case_idx) < len(self.retriever.dataset):
                    self.current_case = self.retriever.dataset[int(case_idx)]
                    similarity_score = 1.0 # Direct selection implies perfect 'match'
                    print(f"Selected case by index {case_idx}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                    retrieved_info = (self.current_case, similarity_score)
                else:
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}

        elif clinical_query:
            # RAG-based retrieval
            try:
                # retrieve_relevant_case returns a list of tuples: [(case_dict, score), ...]
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if retrieved_results: # Check if list is not empty
                    retrieved_info = retrieved_results[0] # Get the first tuple (case_dict, score)
                    self.current_case = retrieved_info[0]
                    similarity_score = retrieved_info[1]
                    print(f"Retrieved case via query ('{clinical_query}') with score {similarity_score:.4f}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                else:
                    print(f"Error: No case found for query: '{clinical_query}'")
                    return {"error": f"No relevant case found for query: {clinical_query}"}
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}
        else:
            # No selection method provided
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        # --- Post-selection setup ---
        if self.current_case is None:
            # This should ideally be caught above, but double-check
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        # Validate case structure: questions/answers must be parallel lists.
        if 'questions' not in self.current_case or 'answers' not in self.current_case or \
           not isinstance(self.current_case['questions'], list) or \
           not isinstance(self.current_case['answers'], list) or \
           len(self.current_case['questions']) != len(self.current_case['answers']):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}

        if not self.current_case['questions']:
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            # A case with zero questions cannot be examined; treat as an error.
            return {"error": "Selected case contains no questions."}


        # Start a new session record
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        # Get the first question
        first_question = self.current_case['questions'][0]

        # Record this interaction
        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(self.current_case['questions'])}")
        print("-" * 50)

        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score, # Use the score from retrieval
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(self.current_case['questions'])
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and return the next question or completion status.

        Args:
            response (str): User's answer text.

        Returns:
            dict: Contains feedback, expected answer, completion status, and next question
            (if applicable), or an error message under "error".
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        if self.current_question_idx >= len(self.current_case['questions']):
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        current_q_num = self.current_question_idx + 1
        total_q = len(self.current_case['questions'])
        print(f"Processing response for Question {current_q_num}/{total_q}")
        print(f"User Response: {response}")

        # Save the user's response to history
        self.session_history.append({
            "role": "resident",
            "content": response
        })

        # Get the expected answer for the current question
        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        # Evaluate the answer with the LLM-based evaluator
        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context = f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")


        # Add feedback to history
        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Move to the next question index
        self.current_question_idx += 1

        # Check if the case is complete
        is_complete = self.current_question_idx >= len(self.current_case['questions'])

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            "is_complete": is_complete,
            "question_number": self.current_question_idx
        }

        # Add next question if not complete
        if not is_complete:
            next_question = self.current_case['questions'][self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q

            # Add next question to history
            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            print(f"Next question ({result['question_number']}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            summary = self.generate_session_summary()
            result["session_summary"] = summary
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })


        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the completed session."""
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        # Simple summary structure
        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history # Include the full log
        }

    def save_session(self, filepath):
        """Save the current session summary to a JSON file.

        Args:
            filepath: Destination path; parent directories are created as needed.

        Returns:
            dict: {"status": ...} on success, {"error": ...} on failure.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        try:
            # Add a timestamp to the saved data
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # Bug fix: os.path.dirname returns '' for a bare filename, and
            # os.makedirs('') raises FileNotFoundError — only create a
            # directory when the path actually contains one.
            target_dir = os.path.dirname(filepath)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}
\ No newline at end of file
diff --git a/src/.ipynb_checkpoints/synthetic_generator-checkpoint.py b/src/.ipynb_checkpoints/synthetic_generator-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..409efe96efb2c47ce08a52fa0215ef0be6df0083
--- /dev/null
+++ b/src/.ipynb_checkpoints/synthetic_generator-checkpoint.py
@@ -0,0 +1,125 @@
+import pandas as pd
+import numpy as np
+import torch
+import re
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical case with examiner questions and expected answers.

    Args:
        clinical_query: Topic of the case to synthesize (e.g. "Appendicitis").
        model_id: Hugging Face causal-LM checkpoint used as the generator.
        max_tokens: Cap on newly generated tokens (prompt length excluded).

    Returns:
        The raw generated text ("Clinical Presentation: ..." followed by
        Qn/An pairs) intended for process_synthetic_data(), or None when the
        model fails to initialize or generation raises before decoding.
    """
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")
    gen_tokenizer = None
    gen_model = None
    try:
        # Initialize generator model components
        gen_tokenizer = AutoTokenizer.from_pretrained(model_id)
        gen_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        gen_model.eval()
        device = gen_model.device
        # Llama-family tokenizers ship without a pad token; reuse EOS so
        # tokenization/generation does not fail on padding.
        if gen_tokenizer.pad_token is None:
            gen_tokenizer.pad_token = gen_tokenizer.eos_token

    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

    # Instruction-style prompt; the exact output format below is what
    # process_synthetic_data()'s regexes expect.
    prompt = f"""[INST] You are a board-certified general surgeon simulating a clinical oral board exam.
    Create a synthetic case on the topic: "{clinical_query}".
    Start by describing the initial clinical presentation in 1–2 sentences.
    Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated.
    Output ONLY the presentation and Q&A pairs in this exact format:
    Clinical Presentation: ...

    Q1: ...
    A1: ...

    Q2: ...
    A2: ...

    (continue until Qn/An)
    Focus on common scenarios and standard knowledge. Avoid overly complex or rare details.
    [/INST]"""

    output_text = None
    try:
        inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)
        input_ids_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = gen_model.generate(
                inputs.input_ids,
                max_new_tokens=max_tokens,
                do_sample=True, # Sample to get potentially varied outputs
                temperature=0.7,
                top_p=0.9,
                pad_token_id=gen_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens
        generated_ids = outputs[0][input_ids_length:]
        output_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")

    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Clean up model resources
        del gen_model
        del gen_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_text
+
def process_synthetic_data(clinical_query, output_text):
    """Process the raw LLM output text into a structured DataFrame for the DummyRetriever."""
    # Pull out the presentation paragraph (everything before the first Q1).
    presentation_match = re.search(r"Clinical Presentation:(.*?)(?=\n\nQ1:|$)", output_text, re.DOTALL | re.IGNORECASE)
    if presentation_match:
        clinical_presentation_text = presentation_match.group(1).strip()
    else:
        clinical_presentation_text = "Synthetic Case: " + clinical_query

    # Numbered Q/A pairs: Qn followed by its matching An (backreference \1).
    qa_pattern = r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)"
    qa_list = []
    for match_tuple in re.findall(qa_pattern, output_text, flags=re.DOTALL | re.IGNORECASE):
        try:
            turn = int(match_tuple[0])
            question = match_tuple[1].strip()
            answer = match_tuple[2].strip()
        except (IndexError, ValueError) as e:
            print(f"Warning: Skipping malformed Q/A match: {match_tuple} due to {e}")
            continue
        if question and answer:
            qa_list.append({'turn_id': turn, 'question': question, 'answer': answer})

    if not qa_list:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    qa_list.sort(key=lambda item: item['turn_id'])

    df_synthetic = pd.DataFrame(
        [
            {
                'case_id': 'SYNTH_01',
                'clinical_presentation': clinical_query,  # query doubles as the case title
                'turn_id': entry['turn_id'],
                'question': entry['question'],
                'answer': entry['answer'],
            }
            for entry in qa_list
        ]
    )

    # Fold the narrative presentation into the first question so the opening
    # turn reads like a real exam stem.
    if not df_synthetic.empty and clinical_presentation_text:
        first_turn_positions = df_synthetic.index[df_synthetic['turn_id'] == 1]
        if len(first_turn_positions) > 0:
            idx = first_turn_positions[0]
            df_synthetic.loc[idx, 'question'] = clinical_presentation_text + " " + df_synthetic.loc[idx, 'question']
        else:
            print("Warning: Could not find turn_id 1 to prepend presentation.")

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic
\ No newline at end of file
diff --git a/src/__pycache__/data_processing.cpython-311.pyc b/src/__pycache__/data_processing.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02bac197d0e760ca02cbb73e75f9a4146a5f1d9b
Binary files /dev/null and b/src/__pycache__/data_processing.cpython-311.pyc differ
diff --git a/src/__pycache__/evaluation_utils.cpython-311.pyc b/src/__pycache__/evaluation_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9e0a076bbeb1fae6d6f9302182dc018fafddc3c
Binary files /dev/null and b/src/__pycache__/evaluation_utils.cpython-311.pyc differ
diff --git a/src/__pycache__/evaluator.cpython-311.pyc b/src/__pycache__/evaluator.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed6d36b41e6823aabd8088dbe64f456ed374286d
Binary files /dev/null and b/src/__pycache__/evaluator.cpython-311.pyc differ
diff --git a/src/__pycache__/retriever.cpython-311.pyc b/src/__pycache__/retriever.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24d57dcd55d48099fbd3a46f010f6abaafb0b7bf
Binary files /dev/null and b/src/__pycache__/retriever.cpython-311.pyc differ
diff --git a/src/__pycache__/simulator.cpython-311.pyc b/src/__pycache__/simulator.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a89f5e9bfe17b76d308550b8f6161f169f707a4
Binary files /dev/null and b/src/__pycache__/simulator.cpython-311.pyc differ
diff --git a/src/__pycache__/synthetic_generator.cpython-311.pyc b/src/__pycache__/synthetic_generator.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..579c23cabeeb6d397d65786557a6b1cbd5ae1c7a
Binary files /dev/null and b/src/__pycache__/synthetic_generator.cpython-311.pyc differ
diff --git a/src/data_processing.py b/src/data_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..136d75448f617a50e50ff9de2bc72a915e5f6634
--- /dev/null
+++ b/src/data_processing.py
@@ -0,0 +1,201 @@
+import os
+import re
+import pandas as pd
+import numpy as np
+import torch
+from docx import Document
+from sentence_transformers import SentenceTransformer
+from datasets import Dataset
+from tqdm import tqdm
+
def read_docx(file_path):
    """Return the full text of a .docx file, paragraphs joined by newlines.

    Any read/parse failure is logged and an empty string is returned.
    """
    try:
        document = Document(file_path)
        lines = [paragraph.text for paragraph in document.paragraphs]
        return '\n'.join(lines)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return ""
+
def extract_qa_pairs(text):
    """Extracts alternating Examiner and Examinee Q&A pairs from text."""
    # Each pair is a bold "**Examiner:**" turn followed (after a blank line)
    # by a bold "**Examinee:**" turn.
    pattern = re.compile(r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)", re.DOTALL)
    pairs = []
    for question, answer in pattern.findall(text):
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs
+
def parse_filename(filename):
    """Parse the case ID and topic from a BTK-style filename.

    Example: "BTK_-_77A___Burn.docx" -> ("77A", "Burn").

    Args:
        filename: Bare filename (with or without extension).

    Returns:
        tuple: (case_id, topic); both are "Unknown" when the name does not
        match the expected "BTK_-_<id>___<topic>" pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Bug fix: the warning previously printed a hard-coded "(unknown)"
        # placeholder instead of interpolating the offending filename.
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic
+
def process_all_cases(folder_path):
    """Read every .docx case file in a folder into a long-format DataFrame.

    Each row is one Q&A turn with columns: case_id, clinical_presentation,
    turn_id, question, answer.

    Args:
        folder_path: Directory containing BTK-style .docx case files.

    Returns:
        pd.DataFrame: One row per turn; empty when the folder is missing or
        no usable files are found.
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)

    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        # Skip Word lock/temp files (names starting with '~').
        if filename.lower().endswith('.docx') and not filename.startswith('~'):
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    # Bug fix: report the actual filename instead of the
                    # hard-coded "(unknown)" placeholder.
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,  # turns are 1-indexed
                        "question": pair["question"],
                        "answer": pair["answer"]
                    })
            else:
                # Bug fix: report the actual filename instead of the
                # hard-coded "(unknown)" placeholder.
                print(f"Warning: Empty content for file {filename}")

    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")

    return pd.DataFrame(rows)
+
+
+# --- ClinicalCaseProcessor Class ---
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers.

    Pipeline: load raw Q&A rows -> group into one record per case -> embed a
    short per-case summary -> save as a Hugging Face dataset on disk that
    the retriever can later load.
    """
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # model_name: any sentence-transformers checkpoint id or local path.
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        # Prefer GPU when available; the same device is passed explicitly to
        # encode() below.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings, or None
            on any load/validation/embedding/save failure.
        """
        # Load data
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None

        print(f"Raw data shape: {df.shape}")

        # Validate necessary columns
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        # Group by case_id to get all Q&A pairs for each case
        # (dropna=False keeps cases whose id/presentation is missing).
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        # Create a new dataframe with one row per case
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct order
            group = group.sort_values('turn_id')

            # Extract questions and answers in order
            questions = group['question'].tolist()
            answers = group['answer'].tolist()

            # Handle potential NaN/None in presentation if groupby didn't drop them
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"

            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question lists).
        # Only this summary string is embedded, not the full transcript.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        # Generate embeddings using sentence-transformers
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []

        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                # Generate embeddings for the batch
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)

            # Combine all batch embeddings
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")

        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None


        # Convert to HF Dataset and add embeddings
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # Save processed dataset
        try:
            os.makedirs(output_path, exist_ok=True) # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None # Return None if saving failed

        return dataset
\ No newline at end of file
diff --git a/src/evaluation_utils.py b/src/evaluation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1463078095037aa4008b72fd68a57951ff57f3cf
--- /dev/null
+++ b/src/evaluation_utils.py
@@ -0,0 +1,67 @@
+import numpy as np
+from sklearn.metrics import ndcg_score
+from src.retriever import ClinicalCaseRetriever, DummyRetriever
+
def retrieval_metrics(retriever_instance: "ClinicalCaseRetriever", queries: "list[str]", gold_ids: "list[str]", k: int = 5) -> "dict | None":
    """
    Calculates retrieval metrics for a set of queries.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance
            (anything exposing retrieve_relevant_case(..., return_scores=True)).
        queries: A list of query strings.
        gold_ids: A list of the expected 'case_id' strings for each query.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing Hit@k, MRR, and NDCG@k scores.

    Note: annotations are quoted (lazy) so this module does not hard-fail at
    import time when the retriever types / PEP 604 syntax are unavailable.
    """
    # --- Initialization ---
    hits, reciprocal_ranks, ndcgs = [], [], []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    # --- Process Each Query ---
    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx+1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(q, top_k=k, return_scores=True)

        # Safely extract IDs, handle missing keys
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # --- Calculate Metrics ---
        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)

        rank = 0
        if is_hit:
            rank = retrieved_ids.index(gold) + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)

        # NDCG calculation
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])

        current_ndcg = 0.0
        if true_relevance.shape[1] == 1:
            # Bug fix: sklearn's ndcg_score raises ValueError for a
            # single-document ranking; with one result NDCG degenerates to
            # that document's relevance (1.0 if it is the gold case).
            current_ndcg = float(true_relevance[0, 0])
        elif true_relevance.shape[1] > 1:
            ndcg_k = min(k, true_relevance.shape[1])  # Ensure k is not out of bounds
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        ndcgs.append(current_ndcg)

        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    # --- Aggregate Results ---
    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0

    print(f"\n--- Overall Retrieval Results (k={k}) --- ")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR: {avg_mrr:.4f}") # Corrected spacing for alignment
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

    return {f"Hit@{k}": avg_hit,
            f"MRR": avg_mrr,
            f"NDCG@{k}": avg_ndcg}
\ No newline at end of file
diff --git a/src/evaluator.py b/src/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6357843b5535e420773f1d0889d6f81406253f4
--- /dev/null
+++ b/src/evaluator.py
@@ -0,0 +1,113 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
class AnswerEvaluator:
    """Evaluates user answers against expected answers using an LLM.

    Loads a causal LM once at construction and reuses it for every
    evaluate_answer() call.
    """

    def __init__(self, model_id="meta-llama/Llama-3.2-3B-Instruct"):
        # model_id: Hugging Face checkpoint id of the judge model.
        # Raises: re-raises any tokenizer/model loading exception so callers
        # fail fast instead of holding a half-initialized evaluator.
        print(f"Initializing AnswerEvaluator with model: {model_id}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            # Llama-family tokenizers ship without a pad token; reuse EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("Set pad_token to eos_token")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            self.model.eval()
            self.device = self.model.device
            print(f"AnswerEvaluator model loaded successfully on device: {self.device}")

        except Exception as e:
            print(f"Error initializing AnswerEvaluator model {model_id}: {e}")
            raise


    def evaluate_answer(self, user_answer, expected_answer, clinical_context=None):
        """
        Compare user answer to expected answer and provide feedback

        Args:
            user_answer: Examinee's response
            expected_answer: Model answer from the dataset
            clinical_context: Optional clinical context to consider

        Returns:
            Feedback string (expected to begin with "ASSESSMENT: ..."), or a
            fixed error message when generation fails.
        """
        context_str = f"Clinical context: {clinical_context}\n\n" if clinical_context else ""

        # Few-shot rubric prompt; the two worked [EXAMPLE] sections anchor the
        # "ASSESSMENT: ..." output format the caller relies on.
        prompt = f"""[INST] You are acting as an expert examiner for the American Board of Surgery (ABS) oral board exam. You are evaluating a general surgery resident’s answer to a clinical question. \n
        Compare the answer provided by the residents to the correct expected answer, which I will provide you with. \n
        Use the grading rubric below to assess their response:

        [RUBRIC]
        - Correct: Resident includes all major points and clinical reasoning aligns closely with the expected answer.
        - Partially Correct: Resident includes some key points but omits others, or reasoning is partially flawed.
        - Incorrect: Resident misses most key points or demonstrates incorrect reasoning.

        {context_str}Here is the model answer that contains the key points expected from the resident:
        {expected_answer}

        Now, here is the resident’s actual response:
        {user_answer}

        Evaluate the resident’s response based **only** on the expected answer above. Do not rely on external knowledge or previous responses.

        Focus your evaluation on:
        1. Which key points were mentioned vs. missed
        2. The accuracy and clarity of the clinical reasoning
        3. Any major omissions or misunderstandings

        Start your output with:
        ASSESSMENT: [Correct / Partially Correct / Incorrect]
        Then write 1–2 clear, specific sentences explaining how the resident’s response compares to the expected answer.

        [EXAMPLE 1]
        Expected answer:
        "The differential diagnosis includes acute appendicitis, mesenteric adenitis, gastroenteritis, UTI, and testicular torsion."

        Resident’s response:
        "My top concern is appendicitis, but I’d also consider things like gastroenteritis or maybe even kidney stones."

        ASSESSMENT: Partially Correct
        The resident mentioned appendicitis and gastroenteritis but missed several other expected differentials like UTI, testicular torsion, and mesenteric adenitis.

        [EXAMPLE 2]
        Expected answer:
        "Initial labs should include CBC, CMP, lipase, and abdominal ultrasound to assess for gallstones."

        Resident’s response:
        "I’d start with a full workup including CBC, liver enzymes, lipase, and an abdominal ultrasound."

        ASSESSMENT: Correct
        The resident included all key labs and the correct imaging modality. Their reasoning aligns well with the expected answer.

        [/INST]"""

        try:
            # NOTE(review): max_length=1024 truncates the *prompt*; very long
            # expected/user answers may lose the tail of the instructions.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device) # Added truncation

            with torch.no_grad():
                # Generate feedback using the model (low temperature keeps the
                # grading output close to deterministic).
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature = 0.2,
                    pad_token_id=self.tokenizer.eos_token_id # Ensure pad token ID is set
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_length_tokens = inputs.input_ids.shape[1]
            generated_ids = outputs[0][prompt_length_tokens:]

            feedback = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            return feedback

        except Exception as e:
            print(f"Error during LLM evaluation: {e}")
            return "Error: Could not generate feedback."
\ No newline at end of file
diff --git a/src/retriever.py b/src/retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..a740910e3b534ec99d0be1fa3ceaa917774d8302
--- /dev/null
+++ b/src/retriever.py
@@ -0,0 +1,172 @@
+import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+from datasets import load_from_disk, Dataset
+import os
+import pandas as pd
+
class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings.

    Expects a Hugging Face dataset (object or saved-to-disk path) with an
    'embeddings' column, as produced by ClinicalCaseProcessor.
    """

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        # dataset_path: a datasets.Dataset object OR a directory previously
        #   written by Dataset.save_to_disk().
        # model_name: sentence-transformers checkpoint; must match the one
        #   used to build the stored embeddings for scores to be meaningful.
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")


        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Ensure embeddings are loaded as a NumPy array
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e


    def get_available_cases(self, n=5):
        """Return a sample of available cases for user selection.

        Returns a list of (index, clinical_presentation) tuples of size
        min(n, len(dataset)); the sample is random without replacement.
        """
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Ensure indices are int for slicing dataset
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Generate embedding for a query string.

        Returns a (1, dim) numpy array, or None on failure.
        """
        # Create a better search query structure (mirrors the "Clinical case"
        # phrasing used when the stored case summaries were embedded).
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")

        # Generate embedding using sentence-transformers
        try:
            query_embedding = self.model.encode([search_query], convert_to_numpy=True, device=self.device, show_progress_bar=False)
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None # Or raise an error

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) given a query.

        Args:
            query: Free-text description of the desired case topic.
            top_k: Number of cases to return (clamped to dataset size).
            return_scores: When True, returns (cases, scores) as two parallel
                lists; when False, returns a list of (case_dict, score)
                tuples. Note the two shapes differ — callers must match.

        Returns:
            See return_scores above; empty on any failure.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Get query embedding
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return [] if not return_scores else ([], [])

        # Calculate similarity scores
        try:
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0] # Get the single row of similarities
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return [] if not return_scores else ([], [])

        # Get indices of top-k most similar cases
        # Ensure we don't request more indices than available cases
        k_actual = min(top_k, len(similarities))
        if k_actual == 0: # Should not happen if dataset loaded, but safe check
            return [] if not return_scores else ([], [])

        # Use partitioning for efficiency if k is much smaller than N, or argsort otherwise
        # Using argsort is generally simpler and fine for moderate N
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int) # Get top k indices, sorted descending

        top_scores = similarities[top_indices].tolist() # Get scores for these indices

        # Return the most relevant case(s)
        try:
            # Retrieve cases safely using integer indices
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return [] if not return_scores else ([], [])
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return [] if not return_scores else ([], [])


        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access presentation, provide default if missing
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        else:
            # Return list of tuples (case_dict, score)
            return results_with_scores
+
+
class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame."""

    def __init__(self, df):
        """Build the in-memory case list from a long-format DataFrame.

        Args:
            df: DataFrame with one row per Q&A turn; must contain columns
                'clinical_presentation', 'turn_id', 'question', 'answer'.
                An optional 'case_id' column is used when present.
        """
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        # Expects df to be pre-processed with columns:
        # 'clinical_presentation', 'turn_id', 'question', 'answer'
        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for i, (scenario, group) in enumerate(grouped):
            group_sorted = group.sort_values('turn_id')  # preserve exam turn order

            case_dict = {
                "case_id": group_sorted['case_id'].iloc[0] if 'case_id' in group_sorted.columns else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": group_sorted["question"].tolist(),
                "answers": group_sorted["answer"].tolist()
            }
            self.dataset.append(case_dict)
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1, return_scores=False):
        """
        Finds the case matching the query string exactly.
        Ignores 'top_k' but mimics the return structure [(case_dict, score)].

        Args:
            scenario_query: Exact 'clinical_presentation' string to look up.
            top_k: Accepted for interface compatibility; unused.
            return_scores: When True, return (cases, scores) as two parallel
                lists. Bug fix: retrieval_metrics() calls every retriever
                with return_scores=True, which previously raised TypeError
                here; ClinicalCaseRetriever already supports this keyword.

        Returns:
            [(case_dict, 1.0)] (or ([case_dict], [1.0]) when
            return_scores=True); empty result when no exact match exists.
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                if return_scores:
                    return [case_dict], [1.0]
                return [(case_dict, 1.0)]

        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        if return_scores:
            return [], []
        return []
\ No newline at end of file
diff --git a/src/simulator.py b/src/simulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f79afe531f267ae7d5d9778534b56bf792825add
--- /dev/null
+++ b/src/simulator.py
@@ -0,0 +1,247 @@
+import json
+import pandas as pd
+# Assuming retriever and evaluator classes are in these files:
+from .retriever import ClinicalCaseRetriever, DummyRetriever
+from .evaluator import AnswerEvaluator
+
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    Drives a question/answer loop over a single clinical case:
    ``start_new_case`` selects a case (by query or index) and returns the
    first question, ``process_user_response`` grades each answer and advances
    the exam, and ``save_session`` persists the full interaction log as JSON.
    """

    def __init__(self, retriever, evaluator):
        """
        Args:
            retriever: ClinicalCaseRetriever or DummyRetriever exposing
                ``dataset`` (list of case dicts) and ``retrieve_relevant_case``.
            evaluator: AnswerEvaluator used to grade free-text answers.

        Raises:
            TypeError: if either collaborator has the wrong type.
        """
        # Fail fast on wiring mistakes rather than at first use.
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")

        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None          # active case dict, or None
        self.current_question_idx = 0     # index of the next unanswered question
        self.session_history = []         # chronological {'role', 'content'} log

    @staticmethod
    def _case_is_valid(case):
        """Return True if `case` carries parallel 'questions'/'answers' lists."""
        return (
            isinstance(case.get('questions'), list)
            and isinstance(case.get('answers'), list)
            and len(case['questions']) == len(case['answers'])
        )

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the
                retriever's dataset. Takes precedence over `clinical_query`.

        Returns:
            dict: Case info and first question, or {"error": ...} on failure.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset state so a failed selection never leaves a stale case active.
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        if case_idx is not None:
            # Direct case selection by dataset index.
            try:
                idx = int(case_idx)
                if not 0 <= idx < len(self.retriever.dataset):
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
                self.current_case = self.retriever.dataset[idx]
                similarity_score = 1.0  # direct selection implies a perfect 'match'
                print(f"Selected case by index {case_idx}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}
        elif clinical_query:
            # RAG-based retrieval; returns a list of (case_dict, score) tuples.
            try:
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if not retrieved_results:
                    print(f"Error: No case found for query: '{clinical_query}'")
                    return {"error": f"No relevant case found for query: {clinical_query}"}
                self.current_case, similarity_score = retrieved_results[0]
                print(f"Retrieved case via query ('{clinical_query}') with score {similarity_score:.4f}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}
        else:
            # No selection method provided.
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        # --- Post-selection setup ---
        if self.current_case is None:
            # Defensive double-check; the branches above should have caught this.
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        if not self._case_is_valid(self.current_case):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}

        if not self.current_case['questions']:
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            return {"error": "Selected case contains no questions."}

        # Start a new session record.
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        first_question = self.current_case['questions'][0]
        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(self.current_case['questions'])}")
        print("-" * 50)

        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(self.current_case['questions'])
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and return the next question or completion status.

        Args:
            response (str): User's answer text.

        Returns:
            dict: feedback, expected_answer, is_complete, question_number,
                plus next_question/total_questions while the case is running,
                or session_summary once it completes; {"error": ...} when no
                case is active or the case already finished.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        if self.current_question_idx >= len(self.current_case['questions']):
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        current_q_num = self.current_question_idx + 1
        total_q = len(self.current_case['questions'])
        print(f"Processing response for Question {current_q_num}/{total_q}")
        print(f"User Response: {response}")

        # Save the user's response to history.
        self.session_history.append({
            "role": "resident",
            "content": response
        })

        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        # Grade the free-text answer against the model answer.
        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")

        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Advance to the next question and check for completion.
        self.current_question_idx += 1
        is_complete = self.current_question_idx >= len(self.current_case['questions'])

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            "is_complete": is_complete,
            "question_number": self.current_question_idx
        }

        if not is_complete:
            next_question = self.current_case['questions'][self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q

            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            print(f"Next question ({result['question_number']}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            # NOTE: the summary holds a *reference* to session_history, so the
            # closing system message appended below also shows up inside it.
            result["session_summary"] = self.generate_session_summary()
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })

        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the completed session."""
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history  # full chronological log
        }

    def save_session(self, filepath):
        """Save the current session summary (plus a timestamp) to a JSON file.

        Returns:
            dict: {"status": ...} on success, {"error": ...} otherwise.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        # Local import: 'os' is not in this module's top-level imports, and
        # it is only needed here.
        import os

        try:
            # Stamp the record so saved sessions are distinguishable.
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # Create the target directory only when the path actually has one
            # (os.makedirs('') would raise for a bare filename).
            dirname = os.path.dirname(filepath)
            if dirname:
                os.makedirs(dirname, exist_ok=True)

            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}
\ No newline at end of file
diff --git a/src/synthetic_generator.py b/src/synthetic_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..409efe96efb2c47ce08a52fa0215ef0be6df0083
--- /dev/null
+++ b/src/synthetic_generator.py
@@ -0,0 +1,125 @@
+import pandas as pd
+import numpy as np
+import torch
+import re
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical case with examiner questions and expected answers.

    Loads the generator LLM from scratch on every call, prompts it for a
    "Clinical Presentation" header plus numbered Q/A pairs, and releases the
    model before returning.

    Args:
        clinical_query: Topic for the synthetic case (e.g. an exam subject).
        model_id: Hugging Face model identifier of the instruction-tuned generator.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The raw generated text with the prompt stripped off, or None if
        model initialization or generation failed.
    """
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")
    gen_tokenizer = None
    gen_model = None
    try:
        # Initialize generator model components
        gen_tokenizer = AutoTokenizer.from_pretrained(model_id)
        gen_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        gen_model.eval()
        device = gen_model.device
        # Some tokenizers ship without a pad token; reuse EOS so generate() works.
        if gen_tokenizer.pad_token is None:
            gen_tokenizer.pad_token = gen_tokenizer.eos_token

    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

    # [INST]...[/INST] instruction format. The output layout is pinned so the
    # downstream regex parser (process_synthetic_data) can extract it.
    prompt = f"""[INST] You are a board-certified general surgeon simulating a clinical oral board exam.
    Create a synthetic case on the topic: "{clinical_query}".
    Start by describing the initial clinical presentation in 1–2 sentences.
    Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated.
    Output ONLY the presentation and Q&A pairs in this exact format:
    Clinical Presentation: ...

    Q1: ...
    A1: ...

    Q2: ...
    A2: ...

    (continue until Qn/An)
    Focus on common scenarios and standard knowledge. Avoid overly complex or rare details.
    [/INST]"""

    output_text = None
    try:
        inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)
        # Prompt length in tokens, used below to slice off the echoed prompt.
        input_ids_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = gen_model.generate(
                inputs.input_ids,
                max_new_tokens=max_tokens,
                do_sample=True,  # Sample to get potentially varied outputs
                temperature=0.7,
                top_p=0.9,
                pad_token_id=gen_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens
        generated_ids = outputs[0][input_ids_length:]
        output_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")

    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Clean up model resources so repeated calls do not exhaust GPU memory.
        del gen_model
        del gen_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_text
+
def process_synthetic_data(clinical_query, output_text):
    """Process the raw LLM output text into a structured DataFrame for the DummyRetriever.

    Parses a "Clinical Presentation: ..." header followed by numbered
    Q<n>:/A<n>: pairs and returns one row per turn with columns:
    case_id, clinical_presentation, turn_id, question, answer.
    The narrative presentation is prepended to the first question so the
    opening turn carries the scenario.

    Args:
        clinical_query: Topic string; used as the case title for every row.
        output_text: Raw generator output, or None/empty on generation failure.

    Returns:
        pd.DataFrame: one row per Q/A turn; empty when `output_text` is
        None/empty or contains no parsable Q&A pairs.
    """
    # Guard: generate_synthetic_case returns None on failure; re.search on
    # None would raise TypeError.
    if not output_text:
        print("Warning: No synthetic text to process.")
        return pd.DataFrame()

    # Extract the clinical presentation. Accept one-or-more newlines before
    # Q1 — the model does not reliably emit a blank line separator.
    match = re.search(r"Clinical Presentation:(.*?)(?=\n+Q1:|$)", output_text, re.DOTALL | re.IGNORECASE)
    clinical_presentation_text = match.group(1).strip() if match else "Synthetic Case: " + clinical_query

    # Q<n>: ... A<n>: ... pairs; the backreference \1 forces the answer
    # number to match its question number.
    qa_pattern = r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)"
    qa_matches = re.findall(qa_pattern, output_text, flags=re.DOTALL | re.IGNORECASE)

    qa_list = []
    for match_tuple in qa_matches:
        try:
            q_num = int(match_tuple[0])
            q_text = match_tuple[1].strip()
            a_text = match_tuple[2].strip()
            if q_text and a_text:  # drop pairs with an empty side
                qa_list.append({'turn_id': q_num, 'question': q_text, 'answer': a_text})
        except (IndexError, ValueError) as e:
            print(f"Warning: Skipping malformed Q/A match: {match_tuple} due to {e}")

    if not qa_list:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    # Present turns in exam order regardless of generation order.
    qa_list.sort(key=lambda item: item['turn_id'])

    rows = [
        {
            'case_id': 'SYNTH_01',
            'clinical_presentation': clinical_query,  # query doubles as presentation title
            'turn_id': item['turn_id'],
            'question': item['question'],
            'answer': item['answer'],
        }
        for item in qa_list
    ]
    df_synthetic = pd.DataFrame(rows)

    # Prepend the narrative presentation to the first question so the exam
    # opens with the scenario description.
    if not df_synthetic.empty and clinical_presentation_text:
        first_turn_index = df_synthetic[df_synthetic['turn_id'] == 1].index
        if not first_turn_index.empty:
            idx = first_turn_index[0]
            df_synthetic.loc[idx, 'question'] = clinical_presentation_text + " " + df_synthetic.loc[idx, 'question']
        else:
            print("Warning: Could not find turn_id 1 to prepend presentation.")

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic
\ No newline at end of file