VuvanAn commited on
Commit
cc37925
·
verified ·
1 Parent(s): b602d4f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .gitignore +9 -0
  3. .gradio/certificate.pem +31 -0
  4. Papers/12911_2025_Article_2954.pdf +3 -0
  5. Papers/2502.11371v1.pdf +3 -0
  6. Papers/ICRALM.pdf +3 -0
  7. Papers/REALM RAG-Driven Enhancement of Multimodal Electronic Health Records Analysis via Large Language Models.pdf +3 -0
  8. Papers/bioengineering-12-00631.pdf +3 -0
  9. README.md +3 -9
  10. docs/data_description.md +173 -0
  11. notebook/An/master/README.md +20 -0
  12. notebook/An/master/__pycache__/app.cpython-311.pyc +0 -0
  13. notebook/An/master/__pycache__/app.cpython-313.pyc +0 -0
  14. notebook/An/master/__pycache__/utils.cpython-313.pyc +0 -0
  15. notebook/An/master/app.py +155 -0
  16. notebook/An/master/config.yaml +25 -0
  17. notebook/An/master/rag_pipeline/__init__.py +8 -0
  18. notebook/An/master/rag_pipeline/__pycache__/__init__.cpython-313.pyc +0 -0
  19. notebook/An/master/rag_pipeline/data_ingest/__pycache__/loader.cpython-313.pyc +0 -0
  20. notebook/An/master/rag_pipeline/data_ingest/loader.py +40 -0
  21. notebook/An/master/rag_pipeline/data_ingest/parser.py +0 -0
  22. notebook/An/master/rag_pipeline/generation/__pycache__/llm_wrapper.cpython-313.pyc +0 -0
  23. notebook/An/master/rag_pipeline/generation/__pycache__/prompt_template.cpython-313.pyc +0 -0
  24. notebook/An/master/rag_pipeline/generation/llm_wrapper.py +59 -0
  25. notebook/An/master/rag_pipeline/generation/prompt_template.py +115 -0
  26. notebook/An/master/rag_pipeline/indexing/chunking/__pycache__/markdown.cpython-313.pyc +0 -0
  27. notebook/An/master/rag_pipeline/indexing/chunking/__pycache__/recursive.cpython-313.pyc +0 -0
  28. notebook/An/master/rag_pipeline/indexing/chunking/markdown.py +54 -0
  29. notebook/An/master/rag_pipeline/indexing/chunking/recursive.py +30 -0
  30. notebook/An/master/rag_pipeline/indexing/embedding/__pycache__/embedding.cpython-313.pyc +0 -0
  31. notebook/An/master/rag_pipeline/indexing/embedding/embedding.py +23 -0
  32. notebook/An/master/rag_pipeline/retrieval/__pycache__/reranker.cpython-313.pyc +0 -0
  33. notebook/An/master/rag_pipeline/retrieval/__pycache__/vector_retriever.cpython-313.pyc +0 -0
  34. notebook/An/master/rag_pipeline/retrieval/graph_retriever.py +4 -0
  35. notebook/An/master/rag_pipeline/retrieval/hybrid_retriever.py +0 -0
  36. notebook/An/master/rag_pipeline/retrieval/reranker.py +8 -0
  37. notebook/An/master/rag_pipeline/retrieval/vector_retriever.py +38 -0
  38. notebook/An/master/test/__pycache__/data_ingest.cpython-313.pyc +0 -0
  39. notebook/An/master/test/__pycache__/eval_lm.cpython-313.pyc +0 -0
  40. notebook/An/master/test/__pycache__/eval_qa.cpython-313.pyc +0 -0
  41. notebook/An/master/test/__pycache__/prepare_retrieve.cpython-313.pyc +0 -0
  42. notebook/An/master/test/__pycache__/test_llm.cpython-313.pyc +0 -0
  43. notebook/An/master/test/__pycache__/test_retrieve.cpython-313.pyc +0 -0
  44. notebook/An/master/test/chatbot_inference.py +23 -0
  45. notebook/An/master/test/data_ingest.py +78 -0
  46. notebook/An/master/test/eval_lm.py +87 -0
  47. notebook/An/master/test/eval_qa.py +106 -0
  48. notebook/An/master/test/prepare_retrieve.py +50 -0
  49. notebook/An/master/test/test_llm.py +9 -0
  50. notebook/An/master/test/test_retrieve.py +39 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Papers/12911_2025_Article_2954.pdf filter=lfs diff=lfs merge=lfs -text
37
+ Papers/2502.11371v1.pdf filter=lfs diff=lfs merge=lfs -text
38
+ Papers/bioengineering-12-00631.pdf filter=lfs diff=lfs merge=lfs -text
39
+ Papers/ICRALM.pdf filter=lfs diff=lfs merge=lfs -text
40
+ Papers/REALM[[:space:]]RAG-Driven[[:space:]]Enhancement[[:space:]]of[[:space:]]Multimodal[[:space:]]Electronic[[:space:]]Health[[:space:]]Records[[:space:]]Analysis[[:space:]]via[[:space:]]Large[[:space:]]Language[[:space:]]Models.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ dataset
2
+ *.pptx
3
+ *.txt
4
+ log*
5
+ __pycache__/
6
+ _n*
7
+ *.pkl
8
+ rag_index*
9
+ vectorstore_*
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
Papers/12911_2025_Article_2954.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:035d91124fda070e92971f08d213ddbfa350724a7597779120e71700e12825be
3
+ size 2174791
Papers/2502.11371v1.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bdb8a6f2aae4646b2226ab9c4a979bf8ab9ecbf850915c7b71fa2fd8dd0ee26
3
+ size 1054366
Papers/ICRALM.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92f8aaad658caff829f70e46afac29d5c97f54a393003dbd14e118ab3f274518
3
+ size 947657
Papers/REALM RAG-Driven Enhancement of Multimodal Electronic Health Records Analysis via Large Language Models.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bec8dc765ee37395758016b553394d831158d5dc3fb6bd0a17a491d20254181a
3
+ size 1718547
Papers/bioengineering-12-00631.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3f319d104e44df2d569a2b77761fee70abe1117fa948ba3011c7873bd2a8087
3
+ size 907982
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Vimedllm
3
- emoji: 👁
4
- colorFrom: green
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: vimedllm
3
+ app_file: notebook/An/master
 
 
4
  sdk: gradio
5
+ sdk_version: 5.41.1
 
 
6
  ---
 
 
docs/data_description.md ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset Description
2
+
3
+ ## Overview
4
+
5
+ This document describes the datasets used in this research project.
6
+
7
+ ## Dataset 1: ViMedAQA (heart related question filtered)
8
+
9
+ ### Description
10
+
11
+ ViMedAQA: A Vietnamese Medical Abstractive Question-Answering Dataset and Findings of Large Language Model
12
+
13
+ ### Source
14
+
15
+ - **URL**: https://huggingface.co/datasets/tmnam20/ViMedAQA/viewer/all/train
16
+ - **Paper**: https://aclanthology.org/2024.acl-srw.31.pdf
17
+ - **License**: [License information]
18
+
19
+ ### Statistics
20
+
21
+ - **Total samples**: 1456
22
+ - **Average text length**: [Number] tokens
23
+ - **Max text length**: [Number] tokens
24
+ - **Min text length**: [Number] tokens
25
+
26
+ ### Format
27
+
28
+ ```json
29
+ {
30
+ "question_idx": "body-part_2201",
31
+ "question": "Khi hình dạng liềm ở góc móng chân tay biến mất thì có thể là dấu hiệu của những tình trạng nào?",
32
+ "answer": "Khi hình dạng liềm ở gốc móng biến mất thì có thể là dấu hiệu của suy dinh dưỡng, trầm cảm hay thiếu máu.",
33
+ "context": "Bạn có nhìn thấy những đường cong nhỏ tròn màu trắng ở gốc móng tay của bạn nhưng không phải ai cũng có chúng. Hầu hết sự có mặt hay không có chúng không có nghĩa lý gì và chúng có thể được ẩn dưới da của bạn. Nếu chúng biến mất, đó có thể là dấu hiệu của tình trạng: - Suy dinh dưỡng.\n- Trầm cảm.\n- Thiếu máu.",
34
+ "title": "Bất thường của móng tay chân - Móng không có hình liềm ở gốc móng",
35
+ "keyword": "Móng tay chân",
36
+ "topic": 0,
37
+ "article_url": "https://youmed.vn/tin-tuc/nhung-bat-thuong-ve-mong-tay-chan/",
38
+ "author": "Bác sĩ Hoàng Thị Việt Trinh",
39
+ "author_url": "https://youmed.vn/tin-tuc/bac-si/bac-si-hoang-thi-viet-trinh/"
40
+ }
41
+
42
+ ```
43
+
44
+ ### Preprocessing Steps
45
+
46
+ 1. **Text cleaning**: Remove special characters, normalize whitespace
47
+ 2. **Tokenization**: Using [tokenizer name]
48
+ 3. **Length filtering**: Remove texts shorter than [X] tokens
49
+ 4. **Label encoding**: Convert labels to numeric format
50
+ 5. **Data splitting**: 80% train, 10% validation, 10% test
51
+
52
+ ## Dataset 2: MedMCQA (heart related filtered)
53
+
54
+ ### Description
55
+
56
+ MedMCQA : A Large-scale Multi-Subject Multi-Choice Dataset for Medical domain Question Answering
57
+
58
+ ### Source
59
+
60
+ - **URL**: https://github.com/medmcqa/medmcqa
61
+ - **Paper**: https://proceedings.mlr.press/v174/pal22a/pal22a.pdf
62
+ - **License**: [License information]
63
+
64
+ ### Statistics
65
+
66
+ - **Total samples**: 2144
67
+
68
+ ### Format
69
+
70
+ ```json
71
+ {
72
+ "id": "405b7c79-b6ac-4407-977c-e5595bba56c4",
73
+ "question": "A 46-year-old man presents with diffuse chest pain at rest and recent history of cough, fever, and rhinorrhea lasting for 3 days.",
74
+ "options": {
75
+ "opa": "Acute pericarditis",
76
+ "opb": "Constrictive pericarditis",
77
+ "opc": "Takotsubo-cardiomyopathy",
78
+ "opd": "Cor pulmonale"
79
+ },
80
+ "correct_option": 0,
81
+ "choice_type": "single",
82
+ "explanation": "Ans. (a) Acute pericarditis. The tracing reveals sinus rhythm at approximately 75 beats/min. The PR interval is prolonged to 200 milliseconds, consistent with borderline first-degree AV block. The QRS axis and intervals are normal. ST elevations with concave upward morphology are seen in I and aVL, II and aVF, and V2 through V6. No Q waves are present. Furthermore, subtle PR-segment depression is seen in leads I and II. The differential diagnosis for ST-segment elevation includes, among other things, acute myocardial infarction, pericarditis, and left ventricular aneurysm. In this case, the upward concavity of the ST segment, the PR-segment depression, the lack of Q waves, and the diffuse nature of the ST-segment elevation in more than one coronary artery distribution make pericarditis the likely etiology. Patients with pericarditis will complain of chest pain, typically described as sharp and pleuritic. Radiation is to the trapezius ridge. The pain is improved with sitting up and leaning forward and worsened by leaning backward.",
83
+ "subject_name": "Medicine",
84
+ "topic_name": "Electrocardiography"
85
+ }
86
+ ```
87
+
88
+ ### Preprocessing Steps
89
+ [List preprocessing steps]
90
+
91
+ ## Dataset 3: MedAB QA
92
+
93
+ ### Description
94
+
95
+ The crawled QA dataset from the online examination.
96
+
97
+ ### Statistics
98
+
99
+ - **Total samples**: 1150
100
+
101
+ ### Format
102
+
103
+ ```json
104
+ {
105
+ "question": "Áp lực tĩnh mạch trung tâm được đo ở............và thường bằng............:",
106
+ "options": {
107
+ "A": "Nhĩ trái; 0 mmHg",
108
+ "B": "Nhĩ phải; 12 cm H2O",
109
+ "C": "Tĩnh mạch chủ trên; -2 mmHg",
110
+ "D": "Tĩnh mạch dưới đòn; 0 mmHg",
111
+ "E": "Nhĩ phải; 0 mmHg"
112
+ },
113
+ "answer": "B"
114
+ }
115
+ ```
116
+
117
+ ## Dataset 4: Mimic_ex
118
+
119
+ ### Description
120
+
121
+ Mimic_ex: A dataset derived from the MIMIC-III database, focusing on medical examinations and related data.
122
+
123
+ ### Source
124
+
125
+ - **URL**: https://physionet.org/content/mimiciii/1.4/
126
+ - **Paper**: https://www.nature.com/articles/sdata201635
127
+ - **License**: [License information]
128
+
129
+ ### Statistics
130
+
131
+ - **Total samples**: 44914
132
+
133
+ ### Format
134
+
135
+ ```txt
136
+ baby girl is a 1,385 gram, former 30 and week premature baby, born to an 18 year old, gravida i, para 0, now i, mother with prenatal serologies as follows: a positive, antibody negative, rpr nonreactive, hepatitis b surface antigen negative; gbs unknown. pregnancy was complicated by pprom on when the mother was transferred from hospital to . mother received betamethasone times two as well as ampicillin and erythromycin. she progressed to a spontaneous vaginal delivery on the morning of . the baby emerged vigorous with spontaneous cry; apgars of eight and nine. she was warm, dried and bulb suctioned in the delivery room and brought to the neonatal intensive care unit for further management for prematurity. physical examination: weight 1,385 grams (25th to 50th percentile); length 38 cms (10 to 25 percentile); head circumference 27.5 cms (10 to 25 percentile). she was an active, alert infant, pink, appropriate for gestational age of 31 weeks. anterior fontanel was open and flat with some molding and caput. no dysmorphism. lungs clear to auscultation. heart regular rate and rhythm without murmurs. abdomen was soft without hepatosplenomegaly or masses. hips were stable. premature female genitalia. extremities were well perfused. hospital course: 1.) respiratory: baby girl remained stable on room air throughout her neonatal intensive care unit stay at . she had one apnea and bradycardia episode on day of life five, requiring mild stimulation. 2.) cardiovascular: baby girl had seemed hemodynamically stable throughout her neonatal intensive care unit stay. she had no murmurs on examination. 3.) fluids, electrolytes and nutrition: baby girl had gradually been advanced to total fluids of 150 cc per kg per day; currently tolerating breast milk 22, maintaining good blood glucose. her admission weight was 1,385 grams; her weight on day of life seven prior to discharge was 1,445 grams. 
gastrointestinal: baby girl ' bilirubin level peaked on day of life three at 8.3, at which time phototherapy was initiated. subsequently, her bilirubin level was 4.2 on day of life six, at which time the phototherapy was discontinued. her rebound bili on day of life seven was 5.1. infectious disease: baby girl was initiated on ampicillin and gentamycin for rule out sepsis. her blood culture remained negative at 48 hours at which time the antibiotics were discontinued. hematology: the patient's initial hematocrit was 42.8 and required no transfusions during this admission. neurology: baby girl had a screening head ultrasound on day of life seven which was negative. condition at transfer: baby girl has been stable on room air and hemodynamically stable, tolerating full feeds of breast milk 22. discharge disposition: baby girl is being discharged to special care nursery. care and recommendations: feeds at discharge: total fluids of 150 cc per kg per day with breast milk 24. medications: none. state newborn screen: sent. follow-up appointment: recommended in two to three days after discharge from the neonatal intensive care unit. discharge diagnoses: prematurity at 31 weeks. rule out sepsis. , m.d. dictated by: medquist36 Procedure: Parenteral infusion of concentrated nutritional substances Enteral infusion of concentrated nutritional substances Other phototherapy Diagnoses: Observation for suspected infectious condition Single liveborn, born in hospital, delivered without mention of cesarean section Neonatal jaundice associated with preterm delivery Other preterm infants, 1,250-1,499 grams 29-30 completed weeks of gestation
137
+ allergies: penicillins attending: chief complaint: cc: major surgical or invasive procedure: stereotactic brain biopsy, neuronavigation guided tumor resection.
138
+ ```
139
+
140
+ ## Dataset 5: YouMed
141
+
142
+ ### Description
143
+
144
+ YouMed: Crawled from QA page of YouMed Website
145
+
146
+ ### Source
147
+
148
+ - **URL**: https://youmed.vn/
149
+
150
+ ### Statistics
151
+
152
+ - **Total samples**: 309
153
+
154
+ ## Dataset 6: ViWiki (heart related article filtered)
155
+
156
+ ### Description
157
+
158
+ ViWiki: Crawled from the Vi Wikipedia website
159
+
160
+ ### Source
161
+
162
+ - **URL**: https://vi.wikipedia.org/wiki
163
+
164
+ ### Statistics
165
+
166
+ - **Total samples**: 250
167
+
168
+
169
+ ## References
170
+
171
+ 1. [Dataset paper citation]
172
+ 2. [Related work citations]
173
+ 3. [Preprocessing methodology citations]
notebook/An/master/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```
2
+ python -m notebook.An.master.test.data_ingest \\
3
+ --data_dir notebook/An/master/data \\
4
+ --vectorstore_dir notebook/An/master/knowledge/vectorstore_1 \\
5
+ --embed_model_name alibaba-nlp/gte-multilingual-base \\
6
+ --chunking_strategy recursive \\
7
+ --chunk_size 2048 \\
8
+ --chunk_overlap 512 \\
9
+ --vectorstore faiss
10
+ ```
11
+
12
+ ```
13
+ python -m notebook.An.master.test.test_retrieve \\
14
+ --query "Heart definition and heart disease" \\
15
+ --vectorstore_dir notebook/An/master/knowledge/vectorstore_1 \\
16
+ --embed_model_name alibaba-nlp/gte-multilingual-base \\
17
+ --retriever_k 4 \\
18
+ --metric cosine \\
19
+ --threshold 0.5 \\
20
+ ```
notebook/An/master/__pycache__/app.cpython-311.pyc ADDED
Binary file (8.21 kB). View file
 
notebook/An/master/__pycache__/app.cpython-313.pyc ADDED
Binary file (7.09 kB). View file
 
notebook/An/master/__pycache__/utils.cpython-313.pyc ADDED
Binary file (10.9 kB). View file
 
notebook/An/master/app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datetime import datetime
3
+
4
+ from .rag_pipeline import ChatAssistant, get_embeddings, vretrieve, retrieve_chatbot_prompt, request_retrieve_prompt
5
+ from .utils import load_local
6
+
7
+
8
+ # DEVELOPER: Add or remove models here.
9
+ AVAILABLE_MODELS = {
10
+ # "mistral large (mistral)": ("mistral-large-2", "mistral"),
11
+ "mistral medium (mistral)": ("mistral-medium", "mistral"),
12
+ "mistral small (mistral)": ("mistral-small", "mistral"),
13
+ "llama3 8B" : ("llama3:8b", "ollama"),
14
+ "llama3.1 8B": ("llama3.1:8b", "ollama"),
15
+ "gpt-oss 20B": ("gpt-oss-20b", "ollama"),
16
+ "gemma3 12B": ("gemma3:12b", "ollama"),
17
+ "gpt 4o mini": ("gpt-4o-mini", "openai"),
18
+ "gpt 4o": ("gpt-4o", "openai"),
19
+ }
20
+ DEFAULT_MODEL_KEY = "mistral medium (mistral)"
21
+
22
+ EMBEDDING_MODEL_ID = "alibaba-nlp/gte-multilingual-base"
23
+ VECTORSTORE_PATH = "notebook/An/master/knowledge/vectorstore_full"
24
+ LOG_FILE_PATH = "log.txt"
25
+ MAX_HISTORY_CONVERSATION = 50
26
+
27
+ # System prompt for the medical assistant
28
+ sys = """
29
+ You are an Medical Assistant specialized in providing information and answering questions related to healthcare and medicine.
30
+ You must answer professionally and empathetically, taking into account the user's feelings and concerns.
31
+ """
32
+
33
+ # --- Initial Setup (runs once) ---
34
+ print("Initializing models and data...")
35
+ embedding_model = get_embeddings(EMBEDDING_MODEL_ID, show_progress=False)
36
+ vectorstore, docs = load_local(VECTORSTORE_PATH, embedding_model)
37
+ print("Initialization complete.")
38
+
39
+
40
# --- Helper Functions ---
def log(log_txt: str):
    """Append a single entry to the shared log file."""
    with open(LOG_FILE_PATH, "a", encoding="utf-8") as fh:
        fh.write(f"{log_txt}\n")
45
+
46
+
47
# --- Core Chatbot Logic ---
def chatbot_logic(message: str, history: list, selected_model_key: str):
    """
    Handles the main logic for receiving a message, performing RAG, and generating a response.

    Args:
        message: The user's latest message.
        history: List of (user_msg, bot_msg) pairs from earlier turns.
        selected_model_key: Key into AVAILABLE_MODELS chosen in the UI.

    Yields:
        The progressively accumulated bot response (for streaming display).
    """
    # 1. Look up the model_id and model_provider from the selected key
    model_id, model_provider = AVAILABLE_MODELS[selected_model_key]

    log(f"** Current time **: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log(f"** User message **: {message}")
    log(f"** Using Model **: {model_id} ({model_provider})")

    # Initialize the assistant with the specified model for this request
    try:
        chat_assistant = ChatAssistant(model_id, model_provider)
    except Exception as e:
        yield f"Error: Could not initialize the model. Please check the ID and provider. Details: {e}"
        return

    # --- RAG Pipeline ---
    # 2. Format conversation history for context (bounded so prompts stay small)
    history = history[-MAX_HISTORY_CONVERSATION:]
    conversation = "".join(f"User: {user_msg}\nBot: {bot_msg}\n" for user_msg, bot_msg in history)
    query_for_rag = conversation + f"User: {message}\nBot:"

    # 3. Generate a search query from the conversation
    rag_query = chat_assistant.get_response(request_retrieve_prompt.format(role="user", conversation=query_for_rag))

    # FIX: the original sliced unconditionally with
    # rag_query[rfind("[") + 1 : rfind("]")]. When the model omitted the
    # brackets, rfind returned -1 and the slice silently dropped the last
    # character (e.g. "NO NEED" -> "NO NEE"), defeating the no-retrieval
    # check below. Only strip brackets when a well-formed pair is present.
    # (The original's .lower() before rfind("[") was a no-op and is removed.)
    open_idx = rag_query.rfind("[")
    close_idx = rag_query.rfind("]")
    if open_idx != -1 and close_idx > open_idx:
        rag_query = rag_query[open_idx + 1:close_idx]

    # 4. Retrieve relevant documents if necessary
    if "NO NEED" not in rag_query:
        retrieve_results = vretrieve(rag_query, vectorstore, docs, k=4, metric="mmr", threshold=0.7)
    else:
        retrieve_results = []

    retrieved_docs = "\n".join([f"Document {i+1}:\n" + doc.page_content for i, doc in enumerate(retrieve_results)])
    log(f"** RAG query **: {rag_query}")
    log(f"** Retrieved documents **:\n{retrieved_docs}")

    # --- Final Response Generation ---
    # 5. Create the final prompt with retrieved context
    final_prompt = retrieve_chatbot_prompt.format(role="user", documents=retrieved_docs, conversation=query_for_rag)

    # 6. Stream the response from the LLM, yielding the running text so the
    # UI can update token by token.
    response = ""
    for token in chat_assistant.get_streaming_response(final_prompt, sys):
        response += token
        yield response

    log(f"** Bot response **: {response}")
    log("=" * 50 + "\n\n")
98
+
99
# --- UI Helper Function ---
def start_new_chat():
    """Reset the UI: clear the chat window and empty the message box."""
    cleared_chat_window, cleared_input_box = None, ""
    return cleared_chat_window, cleared_input_box
103
+
104
# --- Gradio UI ---
# Builds the Blocks app at import time; `chatbot_ui` is launched by the
# __main__ guard below (and is the app entry point on Spaces).
with gr.Blocks(theme="soft") as chatbot_ui:
    gr.Markdown("# MedLLM")
    gr.Markdown("Your conversations are automatically saved to `log.txt` for future reference.")

    # Dropdown whose value is the key into AVAILABLE_MODELS.
    model_selector = gr.Dropdown(
        label="Select Model",
        choices=list(AVAILABLE_MODELS.keys()),
        value=DEFAULT_MODEL_KEY,
    )

    chatbot = gr.Chatbot(label="Chat Window", height=500, bubble_full_width=False, value=None)

    with gr.Row():
        new_chat_btn = gr.Button("✨ New Chat")
        msg_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your question here and press Enter...",
            scale=7 # Make the textbox take more space in the row
        )

    def respond(message, chat_history, selected_model_key):
        """Wrapper function to connect chatbot_logic with Gradio's state."""
        # If chat_history is None (cleared), initialize it as an empty list
        chat_history = chat_history or []
        bot_message_stream = chatbot_logic(message, chat_history, selected_model_key)
        # Append a [user, bot] pair, then fill the bot side in as tokens
        # stream out of chatbot_logic, re-yielding the whole history so the
        # chat window updates live.
        chat_history.append([message, ""])
        for token in bot_message_stream:
            chat_history[-1][1] = token
            yield chat_history

    # Event handler for submitting a message; afterwards clear the input box.
    msg_input.submit(
        respond,
        [msg_input, chatbot, model_selector],
        [chatbot]
    ).then(
        lambda: gr.update(value=""), None, [msg_input], queue=False
    )

    # Event handler for the "New Chat" button
    new_chat_btn.click(
        start_new_chat,
        inputs=None,
        outputs=[chatbot, msg_input],
        queue=False # Use queue=False for instantaneous UI updates
    )
151
+
152
+
153
# --- Launch the App ---
if __name__ == "__main__":
    # share=True exposes a temporary public Gradio link in addition to localhost.
    chatbot_ui.launch(debug=True, share=True)
notebook/An/master/config.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 0.1
2
+
3
+ model:
4
+ name: "llama2:7b"
5
+ temperature: 0.3
6
+ max_tokens: 100000
7
+ provider: "ollama"
8
+ base_url: "http://localhost:11434/v1"
9
+
10
+ rag_config:
11
+ k: 4
12
+ rerank:
13
+ name: "bge-reranker-large"
14
+ model: "BAAI/bge-reranker-large"
15
+ top_n: 100
16
+ embed_model:
17
+ name: "gte-multilingual-base"
18
+ model: "alibaba-nlp/gte-multilingual-base"
19
+ chunk_size: 2048
20
+ chunk_overlap: 512
21
+ similarity_threshold: 0.7
22
+ similarity_metric: "cosine"
23
+
24
+ knowledge:
25
+ vectorstore: "faiss"
notebook/An/master/rag_pipeline/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .generation.llm_wrapper import ChatAssistant
2
+ from .indexing.chunking.recursive import split_document as recursive_chunking
3
+ from .indexing.chunking.markdown import split_document as markdown_chunking
4
+ from .indexing.embedding.embedding import get_embeddings
5
+ from .data_ingest.loader import load_data
6
+ from .generation.prompt_template import *
7
+ from .retrieval.vector_retriever import retrieve as vretrieve
8
+ from .retrieval.reranker import rerank
notebook/An/master/rag_pipeline/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (704 Bytes). View file
 
notebook/An/master/rag_pipeline/data_ingest/__pycache__/loader.cpython-313.pyc ADDED
Binary file (2.09 kB). View file
 
notebook/An/master/rag_pipeline/data_ingest/loader.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from langchain.schema import Document
4
+
5
def load_data(data_path: str, file_type: str) -> "List[Document]":
    """
    Load knowledge data from a specified path and file type.

    Args:
        data_path: The path to the data directory.
        file_type: The type of the data ("pdf" or "txt").
    Returns:
        A list of documents.
    Raises:
        NotImplementedError: For "pdf" (loader not written yet).
        ValueError: For any other unsupported file type. (The original
            silently fell through and returned None here, which surfaced
            later as a confusing TypeError in callers.)
    """
    if file_type == "pdf":
        raise NotImplementedError("PDF loading is not yet implemented.")
    if file_type == "txt":
        return _load_txt(data_path)
    raise ValueError(f"Unsupported file type: {file_type}")
18
+
19
def _load_txt(data_path: str) -> "List[Document]":
    """Read every ``.txt`` file in *data_path* into a ``Document``.

    Args:
        data_path: Directory containing the ``.txt`` files.
    Returns:
        One ``Document`` per readable file. Filenames are sorted so the
        resulting document order is deterministic across filesystems
        (``os.listdir`` order is arbitrary).
    Raises:
        FileNotFoundError: If *data_path* is not a directory.
    """
    if not os.path.isdir(data_path):
        raise FileNotFoundError(f"Error: Directory not found at {data_path}")

    splits = []
    for file_name in sorted(os.listdir(data_path)):
        if not file_name.endswith('.txt'):
            continue
        file_path = os.path.join(data_path, file_name)

        # Best-effort: skip unreadable files rather than aborting the load.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            print(f"Error reading file {file_path}: {e}")
            continue

        splits.append(Document(page_content=content, metadata={"source": file_name}))

    return splits
notebook/An/master/rag_pipeline/data_ingest/parser.py ADDED
File without changes
notebook/An/master/rag_pipeline/generation/__pycache__/llm_wrapper.cpython-313.pyc ADDED
Binary file (2.9 kB). View file
 
notebook/An/master/rag_pipeline/generation/__pycache__/prompt_template.cpython-313.pyc ADDED
Binary file (4 kB). View file
 
notebook/An/master/rag_pipeline/generation/llm_wrapper.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from openai import OpenAI
import backoff

import os

# Base URLs for each supported OpenAI-compatible provider.
_base_url_ ={
    "ollama": "http://localhost:11434/v1",
    "mistral": "https://api.mistral.ai/v1",
    "openai": "https://api.openai.com/v1",
}

# API keys resolved once at import time.
# NOTE(review): os.getenv returns None when the variable is unset, which only
# fails later inside the OpenAI client with a less obvious error — consider
# failing fast here. Ollama ignores the key but the client requires a value.
_api_key_ = {
    "ollama": "ollama",
    "mistral": os.getenv("MISTRAL_API_KEY"),
    "openai": os.getenv("OPENAI_API_KEY"),
}
18
class ChatAssistant:
    """Thin wrapper around an OpenAI-compatible chat-completions endpoint."""

    def __init__(self, model_name: str, provider: str = "ollama"):
        """
        Args:
            model_name: The name of the model to use.
            provider: The provider of the model. Can be "ollama", "mistral", or "openai".

        Raises:
            KeyError: If ``provider`` is not one of the known provider keys.
        """
        self.model_name = model_name
        self.client = OpenAI(
            base_url=_base_url_[provider],
            api_key=_api_key_[provider],
        )

    # FIX: the original decorator had no bound, so any persistent failure
    # (bad API key, unreachable host) retried forever and hung the caller.
    # max_tries keeps transient-error resilience but surfaces real errors.
    @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    def get_response(self, user: str, sys: str = "") -> str:
        """Return the complete (non-streaming) reply for one user turn."""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": sys},
                {"role": "user", "content": user},
            ]
        )
        return response.choices[0].message.content

    @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    def get_streaming_response(self, user: str, sys: str = ""):
        """Yields the response token by token (streaming).

        NOTE(review): as a generator function, the backoff decorator only
        retries generator *creation*; exceptions raised while iterating the
        stream are not retried — confirm this is acceptable.
        """
        response_stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": sys},
                {"role": "user", "content": user},
            ],
            stream=True
        )

        # Iterate over the stream of chunks; the token (if any) is in
        # chunk.choices[0].delta.content.
        for chunk in response_stream:
            token = chunk.choices[0].delta.content
            if token is not None:
                yield token
notebook/An/master/rag_pipeline/generation/prompt_template.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt templates used across the RAG pipeline. All of these are plain
# `str.format` templates; the placeholder names ({question}, {document}, ...)
# are part of the public interface and must not change.

# Multiple-choice medical QA with retrieved context; the model must finish
# with "the answer is <letter>" (parsed downstream by the evaluator).
multichoice_qa_prompt = """
-- DOCUMENT --
{document}
-- END OF DOCUMENT --

-- INSTRUCTION --
You are a medical expert.
Given the documents, you must answer the question by following these steps.
First, you must read the question and the options, and draft an answer for it based on your knowledge.
Second, you must read the documents and check if they can help answer the question.
Third, you cross check the documents with your knowledge and the draft answer.
Finally, you answer the question based on your knowledge and the true documents.
Your response must end with the letter of the most correct option like: "the answer is A".
The entire thought must be under 500 words.
-- END OF INSTRUCTION --

-- QUESTION --
{question}
{options}
-- END OF QUESTION --
"""

# Open-ended QA; the short Vietnamese answer must be wrapped in brackets,
# which the evaluator extracts with rfind("[") / rfind("]").
qa_prompt = """
-- DOCUMENT --
{document}
-- END OF DOCUMENT --

-- INSTRUCTION --
You are a medical expert.
Given the documents, you must answer the question by following these steps.
First, you must read the question and draft an answer for it based on your knowledge.
Second, you must read the documents and check if they can help answer the question.
Third, you cross check the documents with your knowledge and the draft answer.
Finally, you answer the question based on your knowledge and the true documents concisely.
Your response must be as short as possible, in Vietnamese and between brackets like: "[...]".
-- END OF INSTRUCTION --

-- QUESTION --
{question}
-- END OF QUESTION --
"""

# One chatbot turn grounded in retrieved chunks; sources cited inside <<<>>>.
retrieve_chatbot_prompt = """
You are a medical expert.
You are having a conversation with a {role} and you have external documents to help you.
Continue the conversation based on the chat history, the context information, and not prior knowledge.
Before using a retrieved chunk, you must check if it is relevant to the user query. If it is not relevant, you must ignore it.
You use the relevant chunk to answer the question and cite the source inside <<<>>>.
If you don't know the answer, you must say "I don't know".
---------------------
{documents}
---------------------
Given the documents and not prior knowledge, continue the conversation.
---------------------
{conversation}
---------------------
"""

# Ask the model to emit a short retrieval query in brackets, or "[NO NEED]".
# Typo fix: the section markers previously read "COVERSATION".
request_retrieve_prompt = """
--- INSTRUCTION ---
You are having a conversation with a {role}.
You have to provide a short query to retrieve the documents that you need inside the brackets like: "[...]".
If it is something not related to the medical field, or something you do not need the external knowledge to answer, you must write "[NO NEED]".
--- END OF INSTRUCTION ---

--- CONVERSATION ---
{conversation}
--- END OF CONVERSATION ---
"""

# Short (<50 words) single-line Vietnamese answer given question + documents.
answer_prompt = """
-- INSTRUCTION --
You are a medical expert.
Given the documents below, you must answer the question step by step.
First, you must read the question.
Second, you must read the documents and check their reliability.
Third, you cross check with your knowledge.
Finally, you answer the question based on your knowledge and the true documents.

Your answer must be UNDER 50 words, written on 1 line and in Vietnamese.
-- END OF INSTRUCTION --

-- QUESTION --
{question}
-- END OF QUESTION --

-- DOCUMENT --
{document}
-- END OF DOCUMENT --

"""

# English -> Vietnamese question translation, format-preserving, one line.
translate_prompt = """
[ INSTRUCTION ]
You are a Medical translator expert.
Your task is to translate this English question into Vietnamese with EXACTLY the same format and write in 1 line.
[ END OF INSTRUCTION ]

[ QUERY TO TRANSLATE ]
{query}
[ END OF QUERY TO TRANSLATE ]
"""

# Clean up raw PDF-extracted text into well-ordered markdown.
pdf2txt_prompt = """
Rewrite this plain text from a pdf file following the right reading order and these instructions:
- Use markdown format.
- Use same language.
- Keep the content intact.
- Beautify the table.
- No talk.

[ QUERY ]
{query}
[ END OF QUERY ]
"""
notebook/An/master/rag_pipeline/indexing/chunking/__pycache__/markdown.cpython-313.pyc ADDED
Binary file (2.55 kB). View file
 
notebook/An/master/rag_pipeline/indexing/chunking/__pycache__/recursive.cpython-313.pyc ADDED
Binary file (1.52 kB). View file
 
notebook/An/master/rag_pipeline/indexing/chunking/markdown.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
2
+ from langchain.schema import Document
3
+ from typing import List
4
+
5
def _split_single_document(document: Document, chunk_size: int, chunk_overlap: int) -> List[Document]:
    """Split one markdown document by headers, then by size, and prepend provenance.

    Each final chunk's page_content is prefixed with the source path, the
    header trail (H1-H3) it fell under, and a running chunk index.
    """
    # Renamed from __split_1_document__: dunder-style names are reserved by
    # the language; a single leading underscore marks a module-private helper.
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        ("###", "Header 3"),
    ]

    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=headers_to_split_on,
        strip_headers=False,
        return_each_line=False
    )

    md_header_splits = markdown_splitter.split_text(document.page_content)

    # Propagate the parent document's metadata (e.g. source path) to each split.
    for doc in md_header_splits:
        doc.metadata.update(document.metadata)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    final_splits = text_splitter.split_documents(md_header_splits)

    # Prepend source / header / chunk-number context to each chunk's content so
    # retrieved chunks are self-describing.
    for i, doc in enumerate(final_splits):
        header_lines = []
        source_line = f"-- source: {doc.metadata.get('source', 'N/A')}"

        for header_key in ('Header 1', 'Header 2', 'Header 3'):
            if header_key in doc.metadata:
                header_lines.append(doc.metadata[header_key])

        header_content = "\n".join(header_lines)
        chunk_header = f"Chunk {i+1}:"

        # Combine everything into the new page content
        original_content = doc.page_content
        doc.page_content = f"{source_line}\n{header_content}\n{chunk_header}\n{original_content}"

    return final_splits

def split_document(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
    """Split every document and return the flattened list of annotated chunks."""
    split_documents = []
    for doc in documents:
        split_documents.extend(_split_single_document(doc, chunk_size, chunk_overlap))
    return split_documents
notebook/An/master/rag_pipeline/indexing/chunking/recursive.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
2
+ from langchain.schema import Document
3
+ from typing import List
4
+
5
def _split_single_document(document: Document, chunk_size: int, chunk_overlap: int) -> List[Document]:
    """Split one document into overlapping character chunks, copying metadata.

    Renamed from __split_1_document__: dunder-style names are reserved by the
    language; a single leading underscore marks a module-private helper.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    text_chunks = text_splitter.split_text(document.page_content)

    # Each chunk gets its own copy of the parent metadata so later per-chunk
    # mutation cannot leak across sibling chunks.
    return [
        Document(page_content=chunk, metadata=document.metadata.copy())
        for chunk in text_chunks
    ]


def split_document(documents: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
    """Split every document and return the flattened chunk list."""
    split_documents = []
    for doc in documents:
        split_documents.extend(_split_single_document(doc, chunk_size, chunk_overlap))
    return split_documents
notebook/An/master/rag_pipeline/indexing/embedding/__pycache__/embedding.cpython-313.pyc ADDED
Binary file (1.06 kB). View file
 
notebook/An/master/rag_pipeline/indexing/embedding/embedding.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_huggingface import HuggingFaceEmbeddings
2
+
3
+ import torch
4
+
5
# model_name -> HuggingFaceEmbeddings, so each model is loaded only once.
_model_cache = {}

def get_embeddings(model_name: str, show_progress: bool = True, batch_size: int = 15) -> HuggingFaceEmbeddings:
    """
    Get the embeddings model. Cache available.
    Args:
        model_name: The name of the model.
        show_progress: Whether to display a progress bar while encoding.
        batch_size: Encoding batch size (previously hard-coded to 15).
    Returns:
        The embeddings model.
    """
    # NOTE(review): show_progress/batch_size are baked in when the model is
    # first built; later calls with different values still return the cached,
    # originally-configured instance.
    if model_name not in _model_cache:
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            show_progress=show_progress,
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'trust_remote_code': True},
            encode_kwargs={'batch_size': batch_size}
        )
        _model_cache[model_name] = embeddings
    return _model_cache[model_name]
notebook/An/master/rag_pipeline/retrieval/__pycache__/reranker.cpython-313.pyc ADDED
Binary file (504 Bytes). View file
 
notebook/An/master/rag_pipeline/retrieval/__pycache__/vector_retriever.cpython-313.pyc ADDED
Binary file (2.25 kB). View file
 
notebook/An/master/rag_pipeline/retrieval/graph_retriever.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from typing import List, Any

def retrieve(query: str, graphstore: Any = None) -> List[str]:
    """Placeholder graph retriever.

    Not yet implemented. Returns an empty list (instead of the original
    implicit None) so the result matches the annotated return type and
    callers can safely iterate it.
    """
    return []
notebook/An/master/rag_pipeline/retrieval/hybrid_retriever.py ADDED
File without changes
notebook/An/master/rag_pipeline/retrieval/reranker.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from typing import List
4
+
5
+ from langchain.schema import Document
6
+
7
def rerank(docs: List[Document]) -> List[Document]:
    """Placeholder reranker: returns the input documents unchanged (same objects, same order)."""
    return docs
notebook/An/master/rag_pipeline/retrieval/vector_retriever.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from langchain.schema import Document
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+
5
+ from .reranker import rerank
6
+
7
+ from typing import List, Any
8
+
9
def retrieve(query: str, vectorstore: FAISS, docs: List[Document] = None, k: int = 4, metric: str = "cosine", threshold: float = 0.5, reranker: Any = None) -> List[Document]:
    """
    Retrieve documents from the vectorstore based on the query and metric.
    Args:
        query: The query to search for.
        vectorstore: The vectorstore to search in.
        docs: Full document list; required only for the "bm25" metric.
        k: The number of documents to retrieve.
        metric: Retrieval strategy: "cosine", "mmr", or "bm25".
        threshold: Score threshold applied only for the "cosine" metric.
        reranker: The reranker to use for reranking the retrieved documents.
    Returns:
        A list of documents.
    Raises:
        ValueError: If the metric is unsupported, or "bm25" is used without docs.
    """
    if metric == "cosine":
        scored = vectorstore.similarity_search_with_score(query, k=k)
        # NOTE(review): FAISS's similarity_search_with_score returns a raw
        # score whose direction depends on how the index was built (L2
        # distance: lower is better). Keeping docs with score > threshold
        # assumes higher-is-better — confirm against the index construction.
        docs = [doc for doc, score in scored if score > threshold]
    elif metric == "mmr":
        docs = vectorstore.max_marginal_relevance_search(query, k=k)
    elif metric == "bm25":
        from langchain_community.retrievers import BM25Retriever
        if docs is None:
            raise ValueError("Documents not available. BM25 requires ingested or loaded documents.")
        bm25_retriever = BM25Retriever.from_documents(docs)
        # BM25Retriever reads its result count from the `k` attribute; passing
        # k as a call kwarg (as before) is not part of its retrieval API.
        bm25_retriever.k = k
        docs = bm25_retriever.get_relevant_documents(query)
    else:
        raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'cosine', 'mmr', and 'bm25'.")

    if reranker is not None:
        return rerank(docs)
    return docs
notebook/An/master/test/__pycache__/data_ingest.cpython-313.pyc ADDED
Binary file (3.98 kB). View file
 
notebook/An/master/test/__pycache__/eval_lm.cpython-313.pyc ADDED
Binary file (5.68 kB). View file
 
notebook/An/master/test/__pycache__/eval_qa.cpython-313.pyc ADDED
Binary file (6.49 kB). View file
 
notebook/An/master/test/__pycache__/prepare_retrieve.cpython-313.pyc ADDED
Binary file (3.71 kB). View file
 
notebook/An/master/test/__pycache__/test_llm.cpython-313.pyc ADDED
Binary file (603 Bytes). View file
 
notebook/An/master/test/__pycache__/test_retrieve.cpython-313.pyc ADDED
Binary file (2.32 kB). View file
 
notebook/An/master/test/chatbot_inference.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from rag_pipeline import get_embeddings, vretrieve, rerank
from utils import load_local

import argparse

def inference(query, args):
    """Retrieve, rerank and print the documents matching `query`."""
    # Bug fix: the original read an undefined global `args` and took no
    # parameters, while conversation() called it as inference(query).
    embed_model = get_embeddings(args.embed_model_name)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
    retrieve_results = vretrieve(query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)

    retrieve_results = rerank(retrieve_results)

    print(retrieve_results)

def conversation(args):
    """Simple REPL: retrieve documents for each user query until "exit"."""
    while True:
        query = input("User: ")
        if query == "exit":
            break
        inference(query, args)

if __name__ == '__main__':
    # Arguments were never parsed in the original script; defaults mirror
    # test/test_retrieve.py.
    parser = argparse.ArgumentParser()
    parser.add_argument("--vectorstore_dir", type=str, default="notebook/An/master/knowledge/vectorstore_full")
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")
    parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="cosine")
    parser.add_argument("--retriever_k", type=int, default=4)
    parser.add_argument("--threshold", type=float, default=0.5)
    args = parser.parse_args()
    conversation(args)
notebook/An/master/test/data_ingest.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import List
4
+
5
+ from ..rag_pipeline import get_embeddings, load_data
6
+ from ..utils import load_local, save_local
7
+
8
def main(args):
    """Ingest raw documents, chunk them, embed them, and persist the vectorstore."""
    print(f"Log: {args}")

    # Optionally wipe any existing store before ingesting.
    if args.clear_vectorstore:
        import shutil
        if os.path.isdir(args.vectorstore_dir):
            shutil.rmtree(args.vectorstore_dir)

    embed_model = get_embeddings(args.embed_model_name)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)

    new_docs = []
    for data_path in args.data_paths:
        new_docs.extend(load_data(data_path, args.file_type))
    print(f"Got {len(new_docs)} documents.")

    # Chunking strategy is selected by flag; both preserve per-doc metadata.
    if args.chunk_method == "recursive":
        from ..rag_pipeline import recursive_chunking
        new_docs = recursive_chunking(new_docs, args.chunk_size, args.chunk_overlap)
    elif args.chunk_method == "markdown":
        from ..rag_pipeline import markdown_chunking
        new_docs = markdown_chunking(new_docs, args.chunk_size, args.chunk_overlap)
    print(f"Got {len(new_docs)} chunks.")

    from langchain_community.vectorstores import FAISS
    if vectorstore is None:
        vectorstore = FAISS.from_documents(new_docs, embed_model)
        docs = new_docs
        print(f"Successfully consumed {len(new_docs)} documents.")
    else:
        docs.extend(new_docs)
        vectorstore.add_documents(new_docs)

    save_local(args.vectorstore_dir, vectorstore, docs)

    import json
    # "w" (not "a"): appending a second run's JSON object after the first
    # produced a file that was no longer valid JSON.
    with open(os.path.join(args.vectorstore_dir, "config.json"), "w") as f:
        json.dump(vars(args), f)
46
+
47
+
48
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Default corpora to ingest when --data_paths is not supplied on the CLI.
    data_paths = [
        'dataset/RAG_Data/wiki_vi',
        'dataset/RAG_Data/youmed',
        'dataset/RAG_Data/mimic_ex_report',
        'dataset/RAG_Data/Download sach y/OCR',
    ]

    # Dataset params
    # nargs="+" (not type=List[str]): typing generics are not valid argparse
    # converters, so supplying --data_paths on the CLI used to crash.
    parser.add_argument("--data_paths", type=str, nargs="+", required=False, default=data_paths)
    parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")
    parser.add_argument("--file_type", type=str, choices=["pdf", "txt"], default="txt")

    # Model params
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    # Index params
    parser.add_argument("--chunk_size", type=int, default=2048)
    parser.add_argument("--chunk_overlap", type=int, default=512)
    parser.add_argument("--chunk_method", type=str, choices=["recursive", "markdown"], default="markdown")

    # Vectorstore params
    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    # BooleanOptionalAction keeps the original default (True) but — unlike
    # action="store_true" with default=True, which could never be disabled —
    # allows opting out via --no-clear_vectorstore.
    parser.add_argument("--clear_vectorstore", action=argparse.BooleanOptionalAction, default=True)

    args = parser.parse_args()

    main(args)
notebook/An/master/test/eval_lm.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from ..rag_pipeline import qa_prompt
3
+ from ..rag_pipeline import ChatAssistant
4
+ from ..utils import load_qa_dataset, load_prepared_retrieve_docs
5
+
6
+ from typing import List, Optional
7
+ from langchain.schema import Document
8
+
9
def get_answer_from_response(llm_response: str) -> str:
    """Return the LLM response with surrounding whitespace removed."""
    cleaned = llm_response.strip()
    return cleaned
11
+
12
def build_qa_prompt(question: str, document: Optional[List[Document]]) -> str:
    """Format the open-ended QA prompt; `document` may be None (no-context run)."""
    if document is not None:
        numbered = (f"Document {idx + 1}:\n" + doc.page_content for idx, doc in enumerate(document))
        document = '\n'.join(numbered)

    return qa_prompt.format(question=question, document=document)
17
+
18
def process_question(question, prompt, answer, id, args, llm):
    """Query the LLM once and append prompt/response/expected answer to log.txt."""
    llm_response = llm.get_response("", prompt)
    log_entry = (
        f"ID: {id}\n"
        f"{prompt}"
        f"LLM Response:\n{llm_response}\n"
        f"Answer: {answer} \n\n"
    )
    with open("log.txt", "a", encoding="utf-8") as f:
        f.write(log_entry)
    return llm_response
31
+
32
def evaluate_qa(questions, prompts, answers, ids, args, llm):
    """Run process_question over all questions concurrently.

    Returns the LLM responses in *question order*. The original collected
    results in completion order (as_completed), so with max_workers > 1 the
    returned answers no longer lined up with the questions they were later
    joined/copied against.
    """
    import concurrent.futures
    from tqdm import tqdm
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [executor.submit(process_question, questions[i], prompts[i], answers[i], ids[i], args, llm) for i in range(len(questions))]
        # as_completed drives the progress bar only; results are read back in
        # submission order to preserve the question <-> answer correspondence.
        for _ in tqdm(concurrent.futures.as_completed(futures), total=len(questions)):
            pass
        return [future.result() for future in futures]
41
+
42
def main(args):
    """Load the QA set, build prompts (optionally with retrieved docs), query the LLM, and copy the bracketed answers to the clipboard."""
    ids, questions, options, answers = load_qa_dataset(args.qa_file)

    if ids is None:
        raise ValueError(f"No id field in {args.qa_file}.")

    if args.num_docs > 0:
        # Retrieval context must be precomputed; we only truncate it here.
        if args.prepared_retrieve_docs_path is None:
            raise ValueError("No prepared retrieve docs found.")
        documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
        docs = [d[:args.num_docs] for d in documents]
    else:
        docs = [None] * len(questions)

    prompts = [build_qa_prompt(q, d) for q, d in zip(questions, docs)]

    llm = ChatAssistant(args.model_name, args.provider)

    # Separate this run from previous ones in the score log.
    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write("\n")

    qa_results = evaluate_qa(questions, prompts, answers, ids, args, llm)
    # Keep only the text between the last "[" and the last "]" per response.
    qa_results = [r[r.rfind("[") + 1:r.rfind("]")] for r in qa_results]
    import pyperclip
    pyperclip.copy('\n'.join(qa_results))
69
+
70
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Input QA dataset (jsonl) and its pre-computed retrieval results.
    parser.add_argument("--qa_file", type=str, default="dataset/QA Data/random.jsonl")
    parser.add_argument("--prepared_retrieve_docs_path", type=str, default="prepared_retrieve_docs.pkl")

    # LLM / execution parameters.
    parser.add_argument("--model_name", type=str, default="mistral-medium")
    parser.add_argument("--provider", type=str, default="mistral")
    parser.add_argument("--max_workers", type=int, default=4)
    # num_docs == 0 disables retrieval context entirely.
    parser.add_argument("--num_docs", type=int, default=0)

    parser.add_argument("--dataset_path", type=str)

    args = parser.parse_args()

    print(args)

    main(args)
notebook/An/master/test/eval_qa.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from ..rag_pipeline import multichoice_qa_prompt
3
+ from ..rag_pipeline import ChatAssistant
4
+ from ..utils import paralelize, load_qa_dataset, load_prepared_retrieve_docs
5
+
6
+ from datetime import datetime
7
+ from typing import List, Optional
8
+ from langchain.schema import Document
9
+
10
def get_answer_from_response(llm_response: str) -> str:
    """Extract the final answer letter after the last "the answer is " marker.

    Returns the letter uppercased so lowercase model output ("the answer is
    a") still matches the A-E validity check downstream; returns "#" when the
    marker is missing. The original indexed `rfind(...) + 14` directly, so a
    missing marker (rfind == -1) silently returned character 13 of the
    response (or raised IndexError), and lowercase answers never matched.
    (The annotation is also corrected: `chr` is a builtin function, not a type.)
    """
    marker = "the answer is "
    idx = llm_response.lower().rfind(marker)
    if idx == -1 or idx + len(marker) >= len(llm_response):
        return "#"
    return llm_response[idx + len(marker)].upper()
15
+
16
def build_multichoice_qa_prompt(question: str, options: str, document: Optional[List[Document]]) -> str:
    """
    Build the prompt for the multichoice QA task.

    `document` may be None for the no-retrieval baseline; otherwise the docs
    are numbered and joined into a single context string.
    """
    if document is not None:
        numbered = (f"Document {idx + 1}:\n" + doc.page_content for idx, doc in enumerate(document))
        document = '\n'.join(numbered)

    return multichoice_qa_prompt.format(question=question, options=options, document=document)
24
+
25
def process_question(question, prompt, answer, id, args, llm):
    """Ask one multiple-choice question (with retries) and return 1 if correct else 0.

    A valid attempt is one whose extracted answer is a letter A-E; invalid
    attempts are retried up to args.retries times. Each valid attempt is
    appended to log.txt, and a "1"/"0" correctness mark to log_score.txt.
    """
    llm_response = ""
    ans = "#"
    for _ in range(args.retries):
        try:
            llm_response = llm.get_response("", prompt)
            ans = get_answer_from_response(llm_response)
        except Exception as e:
            print(f"Error: {e}")
            ans = "#"
            continue
        if ans in ("A", "B", "C", "D", "E"):
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\n")
                f.write(prompt)
                f.write(f"LLM Response:\n{llm_response}\n")
                f.write(f"Answer: {answer} {ans}\n\n")
            break
    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write("1" if ans == answer else "0")
    return 1 if ans == answer else 0
44
+
45
def evaluate_qa(questions, prompts, answers, ids, args, llm):
    """Score all questions concurrently; returns accuracy in [0, 1].

    Completion order does not matter here because only the count of correct
    answers is accumulated.
    """
    import concurrent.futures
    from tqdm import tqdm
    total = len(questions)
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [
            executor.submit(process_question, q, p, a, qid, args, llm)
            for q, p, a, qid in zip(questions, prompts, answers, ids)
        ]
        correct = sum(
            future.result()
            for future in tqdm(concurrent.futures.as_completed(futures), total=total)
        )
    return correct / total
54
+
55
+
56
def main(args):
    """Evaluate the LLM on a multiple-choice QA file and print accuracy."""
    ids, questions, options, answers = load_qa_dataset(args.qa_file)

    if ids is None:
        raise ValueError(f"No id field in {args.qa_file}.")

    if args.num_docs > 0:
        # Retrieval context must have been precomputed; only truncated here.
        if args.prepared_retrieve_docs_path is None:
            raise ValueError("No prepared retrieve docs found.")
        documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
        docs = [d[:args.num_docs] for d in documents]
    else:
        docs = [None] * len(questions)

    prompts = [
        build_multichoice_qa_prompt(q, opt, d)
        for q, opt, d in zip(questions, options, docs)
    ]

    llm = ChatAssistant(args.model_name, args.provider)

    # Stamp this run in the score log so consecutive runs are distinguishable.
    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write(f"\n{datetime.now()} {args}\n")

    acc = evaluate_qa(questions, prompts, answers, ids, args, llm)
    print(f"Accuracy: {acc}")
81
+
82
+
83
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # QA dataset (jsonl) and its pre-computed retrieval results.
    parser.add_argument("--qa_file", type=str, default="dataset/QA Data/MedAB/MedABv2.jsonl")
    parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedAB/prepared_retrieve_docs_full.pkl")

    # Alternative benchmark, kept for convenience:
    # parser.add_argument("--qa_file", type=str, default="dataset/QA Data/MedMCQA/translated_hard_questions.jsonl")
    # parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedMCQA/prepared_retrieve_docs_full.pkl")

    # Eval params
    parser.add_argument("--model_name", type=str, default="gemma3:12b")
    parser.add_argument("--provider", type=str, default="ollama")
    parser.add_argument("--max_workers", type=int, default=1)
    # num_docs == 0 disables retrieval context entirely.
    parser.add_argument("--num_docs", type=int, default=0)
    # Retries per question until a valid A-E answer is produced.
    parser.add_argument("--retries", type=int, default=4)


    # Dataset params
    parser.add_argument("--dataset_path", type=str)

    args = parser.parse_args()
    print(f"Log:{args}")

    main(args)
notebook/An/master/test/prepare_retrieve.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ from ..rag_pipeline import get_embeddings, vretrieve
5
+ from ..utils import load_local, load_qa_dataset, safe_save_langchain_docs
6
+
7
def main(args):
    """Pre-compute retrieval results for every QA item and save them to disk."""
    embed_model = get_embeddings(args.embed_model_name, show_progress=False)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)

    ids, questions, options, answers = load_qa_dataset(args.qa_data_path)

    # Default query per item: question text plus its options.
    rag_queries = [f"Question: {q}\n{opt}" for q, opt in zip(questions, options)]
    # Optionally override with pre-written queries (one JSON object per line).
    if (args.rag_queries_path is not None) and os.path.exists(args.rag_queries_path):
        import json
        with open(args.rag_queries_path, "r", encoding="utf-8") as f:
            rag_queries = [json.loads(line)["query"] for line in f]

    from tqdm import tqdm
    retrieve_results = [
        vretrieve(query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)
        for query in tqdm(rag_queries, desc="Retrieving documents")
    ]

    safe_save_langchain_docs(retrieve_results, args.prepared_retrieve_docs_path)
23
+
24
+
25
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Dataset params
    parser.add_argument("--qa_data_path", type=str, default="dataset/QA Data/MedMCQA/translated_hard_questions.jsonl")

    # Vectorstore params
    parser.add_argument("--vectorstore_dir", type=str, default="notebook/An/master/knowledge/vectorstore_full")
    parser.add_argument("--prepared_retrieve_docs_path", type=str, default="dataset/QA Data/MedMCQA/prepared_retrieve_docs_full.pkl")
    # Optional jsonl of pre-written retrieval queries (overrides the defaults).
    parser.add_argument("--rag_queries_path", type=str, default=None)

    # Model params
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    # Vectorstore retriever params
    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="mmr")
    parser.add_argument("--retriever_k", type=int, default=20, help="Number of documents to retrieve")
    parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for cosine similarity")
    # Reranker options are parsed but currently unused by main().
    parser.add_argument("--reranker_model_name", type=str, default=None)
    parser.add_argument("--reranker_k", type=int, default=50, help="Number of documents to rerank")

    args = parser.parse_args()
    print(args)

    main(args)
notebook/An/master/test/test_llm.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from ..rag_pipeline import ChatAssistant
from ..rag_pipeline import request_retrieve_prompt

# Smoke test: runs at import/execution time (no __main__ guard) and asks the
# assistant to produce a retrieval query for a sample medical topic.
cb = ChatAssistant("mistral-medium", "mistral")

query = "Beta blocker for hypertension"
# Wrap the raw topic in the retrieval-request prompt (role: customer).
query = request_retrieve_prompt.format(conversation=query, role="customer")
response = cb.get_response(user=query)
print(response)
notebook/An/master/test/test_retrieve.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ from ..rag_pipeline import get_embeddings, rerank
5
+ from ..utils import load_local
6
+
7
+ from ..rag_pipeline import vretrieve
8
+
9
def main(args):
    """Run a single retrieval (plus pass-through rerank) and print the hits."""
    embed_model = get_embeddings(args.embed_model_name)
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)

    hits = vretrieve(args.query, vectorstore, docs, args.retriever_k, args.metric, args.threshold)
    hits = rerank(hits)

    print(hits)
17
+
18
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--query", type=str, required=False, default="What are the applications of beta blockers in the treatment of hypertension?")

    # Vectorstore params
    parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")

    # Model params
    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    # Vectorstore retriever params
    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    parser.add_argument("--metric", type=str, choices=["cosine", "mmr", "bm25"], default="cosine")
    parser.add_argument("--retriever_k", type=int, default=4, help="Number of documents to retrieve")
    parser.add_argument("--threshold", type=float, default=0.7, help="Threshold for cosine similarity")
    # Reranker options are parsed but currently unused by main().
    parser.add_argument("--reranker_model_name", type=str, default=None)
    parser.add_argument("--reranker_k", type=int, default=20, help="Number of documents to rerank")

    args = parser.parse_args()

    main(args)