LiamKhoaLe committed on
Commit 80cb919 · 0 Parent(s)
Files changed (21)
  1. .DS_Store +0 -0
  2. .gitattributes +35 -0
  3. .gitignore +4 -0
  4. DATA_PROCESSING.md +250 -0
  5. Dockerfile +31 -0
  6. LICENSE.txt +201 -0
  7. README.md +32 -0
  8. REQUEST.md +156 -0
  9. app.py +423 -0
  10. mount_drive.py +9 -0
  11. requirements.txt +13 -0
  12. utils/__init__.py +22 -0
  13. utils/.DS_Store +0 -0
  14. utils/augment.py +105 -0
  15. utils/datasets.py +66 -0
  16. utils/drive_saver.py +88 -0
  17. utils/llm.py +186 -0
  18. utils/processor.py +411 -0
  19. utils/rag.py +345 -0
  20. utils/schema.py +68 -0
  21. utils/token.py +107 -0
.DS_Store ADDED
Binary file (6.15 kB).
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .env
+ client1.json
+ client2.json
+ medai.json
DATA_PROCESSING.md ADDED
@@ -0,0 +1,250 @@
+ # 📊 MedAI Data Processing Techniques
+
+ This document outlines the data processing techniques implemented in the MedAI Processing project for augmenting medical datasets and centralizing them for LLM fine-tuning.
+
+ ## 🎯 Project Overview
+
+ The MedAI Processing system transforms raw medical datasets into a **centralized fine-tuning format** (JSONL + CSV) with comprehensive data augmentation capabilities. The system processes multiple medical dataset types and applies various enhancement techniques to improve data quality and diversity.
+
+ ## 🏗️ System Architecture
+
+ ### Core Components
+ - **FastAPI Web Service**: RESTful API for dataset processing
+ - **Multi-LLM Rotator**: NVIDIA API + Google Gemini integration
+ - **Centralized Writer**: Parallel JSONL + CSV output generation
+ - **Google Drive Integration**: Automated artifact storage
+ - **Progress Monitoring**: Real-time job status tracking
+
+ ### Supported Datasets
+ 1. **HealthCareMagic** (100k medical dialogues)
+ 2. **iCliniq** (10k medical consultations)
+ 3. **PubMedQA-Labelled** (biomedical Q&A with answers)
+ 4. **PubMedQA-Unlabelled** (biomedical Q&A without answers)
+ 5. **PubMedQA-Map** (biomedical Q&A mapping format)
+
+ ## 🔧 Data Processing Pipeline
+
+ ### 1. Data Ingestion & Download
+ - **Hugging Face Hub Integration**: Automatic dataset downloading
+ - **Format Detection**: JSON/JSONL auto-detection and parsing
+ - **Caching System**: Local storage with symlink optimization
+
+ ### 2. Data Cleaning & Preprocessing
+
+ #### Text Normalization
+ - **Unicode Fixing**: `ftfy` library for text encoding issues
+ - **Whitespace Standardization**: Consistent spacing and line breaks
+ - **Quote Canonicalization**: Standard quote character conversion
+ - **Terminal Punctuation**: Ensures proper sentence endings
+
+ #### Content Sanitization
+ - **Length Capping**: Configurable maximum character limits (default: 5000)
+ - **Language Detection**: English language validation using `langid`
+ - **Content Truncation**: Smart sentence boundary cutting for long texts
+
+ ### 3. Data Augmentation Techniques
+
+ #### LLM-Based Paraphrasing
+ - **Multi-Model Rotation**: NVIDIA API (primary) + Gemini (fallback); see the sketch below
+ - **Difficulty Levels**: Easy vs. Hard paraphrasing modes
+ - **Medical Context Preservation**: Maintains clinical terminology accuracy
+ - **Configurable Ratios**: User-defined augmentation percentages (0.0-1.0)
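+
+ A minimal sketch of the rotation pattern, assuming separate NVIDIA and Gemini callables; the project's actual logic lives in `utils/llm.py` and may differ:
+
+ ```python
+ # Illustrative primary/fallback rotation; nvidia_call and gemini_call
+ # are assumed callables, not the project's real function names.
+ def paraphrase_with_fallback(text: str, difficulty: str, nvidia_call, gemini_call) -> str:
+     try:
+         return nvidia_call(text, difficulty)   # primary: NVIDIA API
+     except Exception:
+         return gemini_call(text, difficulty)   # fallback: Google Gemini
+ ```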
+
+ #### Back-Translation Augmentation
+ - **Multi-Language Support**: German as intermediate language (see the sketch below)
+ - **Meaning Preservation**: Maintains semantic accuracy through translation cycles
+ - **Fallback Mechanisms**: Automatic retry with alternative models
+ - **Quality Control**: Length and content validation
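+
+ A minimal sketch of the round trip, assuming a generic `translate` callable; the project routes this through `paraphraser.backtranslate(text, via_lang="de")` in `utils/augment.py`:
+
+ ```python
+ # Hypothetical EN -> DE -> EN round trip; `translate` is an assumed
+ # callable, and the 50% length gate is an illustrative quality check.
+ def backtranslate(text: str, translate, via_lang: str = "de") -> str:
+     pivot = translate(text, source="en", target=via_lang)   # forward pass
+     back = translate(pivot, source=via_lang, target="en")   # return pass
+     return back if back and len(back) >= 0.5 * len(text) else text
+ ```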
+
+ #### Style Standardization
+ - **Clinical Voice Enforcement**: Neutral, professional medical tone
+ - **Absolute Language Removal**: Replaces guarantees with probabilistic language
+ - **Forum Sign-off Removal**: Eliminates informal communication patterns
+ - **Consistent Punctuation**: Standardized sentence structure
+
+ ### 4. Data Quality Assurance
+
+ #### De-identification (PHI Removal)
+ - **Email Redaction**: `[REDACTED_EMAIL]` placeholder
+ - **Phone Number Masking**: `[REDACTED_PHONE]` placeholder
+ - **URL/IP Address Removal**: `[REDACTED_URL]` and `[REDACTED_IP]` placeholders
+ - **Configurable Privacy**: Optional PHI removal per dataset
+
+ #### Deduplication
+ - **Fingerprinting Algorithm**: MD5-based content hashing (see the example below)
+ - **Multi-Field Matching**: Instruction + Input + Output combination
+ - **Normalized Comparison**: Case-insensitive, whitespace-normalized matching
+ - **Performance Optimized**: In-memory set-based deduplication
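+
+ This is essentially the `fingerprint` helper from `utils/augment.py` (included later in this commit), combined with an in-memory set:
+
+ ```python
+ import hashlib
+ import re
+
+ def fingerprint(instr: str, user: str, out: str) -> str:
+     # Lowercase, strip non-alphanumerics, collapse whitespace, then hash.
+     def norm(x: str) -> str:
+         return re.sub(r"\s+", " ", re.sub(r"[^a-z0-9]+", " ", x.lower())).strip()
+     core = "||".join([norm(instr), norm(user), norm(out)])
+     return hashlib.md5(core.encode("utf-8")).hexdigest()
+
+ seen = set()
+ def is_duplicate(instr: str, user: str, out: str) -> bool:
+     fp = fingerprint(instr, user, out)
+     if fp in seen:
+         return True          # skip: already written
+     seen.add(fp)
+     return False
+ ```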
+
+ #### Consistency Validation
+ - **LLM-Based QA Check**: Automated answer validation against context (sketched below)
+ - **Configurable Sampling**: Ratio-based consistency checking (e.g., 0.01)
+ - **Medical Safety Validation**: Ensures clinical accuracy and safety
+ - **Failure Tagging**: Marks samples with consistency issues
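+
+ A hypothetical prompt-based check is sketched below; the real implementation is `paraphraser.consistency_check` in `utils/llm.py`, which is not shown in this excerpt:
+
+ ```python
+ # Illustrative YES/NO validation prompt; `llm_call` is an assumed callable.
+ def consistency_check(question: str, answer: str, llm_call) -> bool:
+     prompt = (
+         "Does the answer address the question accurately and safely for a "
+         f"medical context? Reply YES or NO.\nQuestion: {question}\nAnswer: {answer}"
+     )
+     return llm_call(prompt).strip().upper().startswith("YES")
+ ```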
+
+ ### 5. Advanced Augmentation Features
+
+ #### Knowledge Distillation
+ - **Pseudo-Label Generation**: Creates labels for unlabeled data
+ - **Fractional Processing**: Configurable percentage for distillation
+ - **Single-Prompt Approach**: Efficient single LLM call per sample (sketched below)
+ - **Length Control**: Maintains reasonable output lengths
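+
+ A minimal sketch of the idea, assuming a generic `llm_call`; the project's actual distillation prompt and length limits live in `utils/processor.py` (not shown here):
+
+ ```python
+ import random
+
+ # Illustrative pseudo-labelling: only a configurable fraction of the
+ # unlabeled samples gets one LLM call each.
+ def maybe_distill(question: str, context: str, llm_call, fraction: float):
+     if random.random() >= fraction:
+         return None                       # leave this sample unlabeled
+     prompt = (
+         "Answer concisely using only the context.\n"
+         f"Context: {context}\nQuestion: {question}"
+     )
+     return llm_call(prompt)[:1000]        # assumed cap to keep outputs short
+ ```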
+
+ #### Multi-Variant Generation
+ - **Configurable Counts**: 1-3 augmented variants per sample
+ - **Tagged Augmentations**: Tracks applied augmentation techniques
+ - **Original Preservation**: Always maintains base sample
+ - **Randomized IDs**: Unique identifiers for augmented variants
+
+ ### 6. Output Generation & Storage
+
+ #### Centralized Format
+ - **SFT Schema**: Standardized Supervised Fine-Tuning format
+ - **Metadata Preservation**: Source, task type, and augmentation tags
+ - **Dual Output**: Simultaneous JSONL and CSV generation (sketched below)
+ - **Memory-Safe Streaming**: Handles large datasets efficiently
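+
+ A minimal sketch of such a writer is shown below. The real `CentralisedWriter` lives in `utils/schema.py`, which is not shown in this excerpt, so the field names here are assumptions:
+
+ ```python
+ import csv
+ import json
+
+ class DualWriter:
+     """Stream each row to JSONL and CSV so large datasets never sit in memory."""
+     def __init__(self, jsonl_path: str, csv_path: str):
+         self.jf = open(jsonl_path, "w", encoding="utf-8")
+         self.cf = open(csv_path, "w", newline="", encoding="utf-8")
+         self.cw = csv.DictWriter(self.cf, fieldnames=["instruction", "input", "output"])
+         self.cw.writeheader()
+
+     def write(self, row: dict) -> None:
+         self.jf.write(json.dumps(row, ensure_ascii=False) + "\n")   # full record + metadata
+         self.cw.writerow({k: row.get(k, "") for k in ("instruction", "input", "output")})
+
+     def close(self) -> None:
+         self.jf.close()
+         self.cf.close()
+ ```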
+
+ #### Storage Integration
+ - **Local Caching**: `cache/outputs/` directory storage
+ - **Google Drive Upload**: Automated cloud storage integration
+ - **Timestamped Naming**: Unique file identification
+ - **MIME Type Handling**: Proper content type specification
+
+ ## ⚙️ Configuration Options
+
+ ### Augmentation Parameters
+ ```python
+ class AugmentOptions:
+     paraphrase_ratio: float = 0.0         # 0.0-1.0
+     paraphrase_outputs: bool = False      # Augment model answers
+     backtranslate_ratio: float = 0.0      # 0.0-1.0
+     style_standardize: bool = True        # Enforce clinical style
+     deidentify: bool = True               # Remove PHI
+     dedupe: bool = True                   # Remove duplicates
+     max_chars: int = 5000                 # Text length limit
+     consistency_check_ratio: float = 0.0  # 0.0-1.0
+     distill_fraction: float = 0.0         # 0.0-1.0 for unlabeled
+     expand: bool = True                   # Enable augmentation
+     max_aug_per_sample: int = 2           # 1-3 variants
+ ```
+
+ ### Processing Parameters
+ ```python
+ class ProcessParams:
+     augment: AugmentOptions             # Augmentation settings
+     sample_limit: Optional[int] = None  # Dataset sampling
+     seed: int = 42                      # Reproducibility
+ ```
+
+ ## 📈 Performance & Monitoring
+
+ ### Progress Tracking
+ - **Real-time Updates**: Live progress percentage and status messages
+ - **Background Processing**: Non-blocking job execution
+ - **State Management**: Thread-safe status tracking
+ - **Error Handling**: Comprehensive exception logging
+
+ ### Resource Management
+ - **API Key Rotation**: Automatic fallback between multiple API keys
+ - **Rate Limiting**: Configurable request throttling
+ - **Memory Optimization**: Streaming processing for large datasets
+ - **Concurrent Processing**: Background task execution
+
+ ## 🔒 Security & Privacy
+
+ ### Data Protection
+ - **PHI Removal**: Automatic sensitive information redaction
+ - **Secure Storage**: Google Drive integration with OAuth2
+ - **Access Control**: Environment-based API key management
+ - **Audit Logging**: Comprehensive processing logs
+
+ ### API Security
+ - **OAuth2 Integration**: Google Drive authentication
+ - **Token Management**: Secure credential handling
+ - **Request Validation**: Pydantic model validation
+ - **Error Sanitization**: Safe error message handling
+
+ ## 🚀 Usage Examples
+
+ ### Basic Processing
+ ```bash
+ # Process HealthCareMagic with default settings
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{"augment": {"paraphrase_ratio": 0.1}}' \
+   https://binkhoale1812-medai-processing.hf.space/process/healthcaremagic
+ ```
+
+ ### Advanced Augmentation
+ ```bash
+ # Process with comprehensive augmentation
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{
+     "augment": {
+       "paraphrase_ratio": 0.2,
+       "backtranslate_ratio": 0.1,
+       "paraphrase_outputs": true,
+       "style_standardize": true,
+       "deidentify": true,
+       "dedupe": true,
+       "max_chars": 5000,
+       "consistency_check_ratio": 0.01,
+       "max_aug_per_sample": 3
+     },
+     "sample_limit": 1000,
+     "seed": 42
+   }' \
+   https://binkhoale1812-medai-processing.hf.space/process/icliniq
+ ```
+
+ ## 📊 Output Statistics
+
+ ### Processing Metrics
+ - **Written Rows**: Total processed samples
+ - **Paraphrased Inputs**: Count of augmented user inputs
+ - **Paraphrased Outputs**: Count of augmented model responses
+ - **Back-translated**: Count of translation-augmented samples
+ - **Deduplication**: Count of skipped duplicate samples
+ - **Consistency Failures**: Count of validation failures
+
+ ### File Outputs
+ - **JSONL Format**: Structured fine-tuning data with metadata
+ - **CSV Format**: Simplified tabular representation
+ - **Google Drive**: Cloud storage with automatic upload
+ - **Local Cache**: Persistent local storage
+
+ ## 🔮 Future Enhancements
+
+ ### Planned Features
+ - **Additional Dataset Support**: More medical dataset types
+ - **Advanced Augmentation**: More sophisticated LLM techniques
+ - **Quality Metrics**: Automated data quality scoring
+ - **Batch Processing**: Concurrent processing of multiple datasets
+ - **Custom Schemas**: User-defined output formats
+
+ ### Scalability Improvements
+ - **Distributed Processing**: Multi-node processing support
+ - **Streaming Augmentation**: Real-time data enhancement
+ - **Caching Optimization**: Improved performance and cost efficiency
+ - **API Rate Limiting**: Better resource management
+
+ ## 📚 Technical Dependencies
+
+ ### Core Libraries
+ - **FastAPI**: Web framework for API development
+ - **Hugging Face Hub**: Dataset downloading and management
+ - **Google GenAI**: Gemini model integration
+ - **ftfy**: Text encoding and normalization
+ - **langid**: Language detection
+ - **orjson**: High-performance JSON processing
+
+ ### External Services
+ - **NVIDIA API**: Primary LLM service for paraphrasing
+ - **Google Gemini**: Fallback LLM service
+ - **Google Drive**: Cloud storage integration
+ - **Hugging Face Spaces**: Deployment platform
+
+ ---
+
+ *This document provides a comprehensive overview of the data processing techniques implemented in the MedAI Processing project. For specific implementation details, refer to the individual module files in the `utils/` directory.*
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.11-slim
+
+ # Install system dependencies as root (no sudo!)
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     ca-certificates curl && rm -rf /var/lib/apt/lists/*
+
+ # Create non-root user
+ RUN useradd -m -u 1000 user
+ ENV HOME=/home/user
+ WORKDIR $HOME/app
+
+ # Install Python dependencies first (better layer caching)
+ COPY --chown=user requirements.txt .
+ RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application
+ COPY --chown=user . .
+
+ # Hugging Face cache setup
+ ENV HF_HOME="$HOME/.cache/huggingface"
+ ENV SENTENCE_TRANSFORMERS_HOME="$HOME/.cache/huggingface/sentence-transformers"
+ ENV MEDGEMMA_HOME="$HOME/.cache/huggingface/sentence-transformers"
+
+ # Prepare runtime dirs
+ RUN mkdir -p $HOME/app/logs $HOME/app/cache $HOME/app/cache/hf $HOME/app/cache/outputs && \
+     chown -R user:user $HOME/app
+
+ USER user
+
+ EXPOSE 7860
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright 2025 Dang Khoa Le
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
README.md ADDED
@@ -0,0 +1,32 @@
+ ---
+ title: MedAI Processing
+ emoji: ⚕️
+ colorFrom: indigo
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ short_description: Process and centralise medical docs for LLM fine-tuning
+ ---
+
+ ## Quick Access
+
+ [HF Space](https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing)
+
+ [MedDialog-100k](https://huggingface.co/datasets/BinKhoaLe1812/MedDialog-EN-100k)
+
+ [MedDialog-10k](https://huggingface.co/datasets/BinKhoaLe1812/MedDialog-EN-10k)
+
+ [PubMedQA-Labelled](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-L)
+
+ [PubMedQA-Unlabelled](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-U)
+
+ [PubMedQA-Mapper](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-MAP)
+
+
+ ## cURL Request Instructions
+ [Request Doc](https://huggingface.co/spaces/MedAI-COS30018/MedAI_Processing/blob/main/REQUEST.md)
+
+ ## License
+ [Apache-2.0 LICENSE](https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing/blob/main/LICENSE.txt)
+
REQUEST.md ADDED
@@ -0,0 +1,156 @@
+ # 📑 MedAI Processing – Request Examples
+
+ Base URL of the Space:
+ **`https://binkhoale1812-medai-processing.hf.space`**
+
+ This Space processes medical datasets into a centralised fine-tuning format (JSONL + CSV) with optional augmentations such as **paraphrasing**, **back-translation**, **style standardisation**, **de-identification**, and **deduplication**.
+
+ ---
+
+ ## 🔹 1. Process HealthCareMagic
+
+ ```bash
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{
+     "augment": {
+       "paraphrase_ratio": 0.1,
+       "backtranslate_ratio": 0.05,
+       "paraphrase_outputs": false,
+       "style_standardize": true,
+       "deidentify": true,
+       "dedupe": true,
+       "max_chars": 5000
+     },
+     "sample_limit": 2000,
+     "seed": 42
+   }' \
+   https://binkhoale1812-medai-processing.hf.space/process/healthcaremagic
+ ```
+
+ ---
+
+ ## 🔹 2. Process iCliniq
+
+ ```bash
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{
+     "augment": {
+       "paraphrase_ratio": 0.2,
+       "backtranslate_ratio": 0.1,
+       "paraphrase_outputs": true,
+       "style_standardize": true,
+       "deidentify": true,
+       "dedupe": true,
+       "max_chars": 5000
+     },
+     "sample_limit": 1500,
+     "seed": 123
+   }' \
+   https://binkhoale1812-medai-processing.hf.space/process/icliniq
+ ```
+
+ ---
+
+ ## 🔹 3. Process PubMedQA (Labelled)
+
+ ```bash
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{
+     "augment": {
+       "paraphrase_ratio": 0.05,
+       "backtranslate_ratio": 0.02,
+       "paraphrase_outputs": false,
+       "style_standardize": true,
+       "deidentify": false,
+       "dedupe": true,
+       "max_chars": 8000
+     },
+     "sample_limit": 1000,
+     "seed": 99
+   }' \
+   https://binkhoale1812-medai-processing.hf.space/process/pubmedqa_l
+ ```
+
+ ---
+
+ ## 🔹 4. Process PubMedQA (Unlabelled)
+
+ ```bash
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{
+     "augment": {
+       "paraphrase_ratio": 0.05,
+       "backtranslate_ratio": 0.05,
+       "paraphrase_outputs": false,
+       "style_standardize": true,
+       "deidentify": true,
+       "dedupe": true,
+       "max_chars": 7000,
+       "consistency_check_ratio": 0.01,
+       "distill_fraction": 0.1
+     },
+     "sample_limit": 500,
+     "seed": 7
+   }' \
+   https://binkhoale1812-medai-processing.hf.space/process/pubmedqa_u
+ ```
+
+ ---
+
+ ## 🔹 5. Process PubMedQA (Map)
+
+ ```bash
+ curl -X POST \
+   -H "Content-Type: application/json" \
+   -d '{
+     "augment": {
+       "paraphrase_ratio": 0.1,
+       "backtranslate_ratio": 0.05,
+       "paraphrase_outputs": true,
+       "style_standardize": true,
+       "deidentify": true,
+       "dedupe": true,
+       "max_chars": 6000
+     },
+     "sample_limit": 1200,
+     "seed": 2024
+   }' \
+   https://binkhoale1812-medai-processing.hf.space/process/pubmedqa_map
+ ```
+
+ ---
+
+ ## 🔹 6. Check Current Job Status
+
+ ```bash
+ curl https://binkhoale1812-medai-processing.hf.space/status
+ ```
+
+ ---
+
+ ## 🔹 7. List Generated Artifacts
+
+ ```bash
+ curl https://binkhoale1812-medai-processing.hf.space/files
+ ```
+
+ ---
+
+ # ✅ Notes
+
+ * Each run outputs both `.jsonl` and `.csv` in `cache/outputs/` and also uploads them to Google Drive folder ID:
+   `1JvW7its63E58fLxurH8ZdhxzdpcMrMbt`
+ * `augment` options can be adjusted per dataset:
+
+   * `paraphrase_ratio` – % of rows paraphrased (0–1)
+   * `backtranslate_ratio` – % of rows back-translated
+   * `paraphrase_outputs` – whether to also augment model answers
+   * `style_standardize` – enforce neutral, clinical style
+   * `deidentify` – redact PHI (emails, phones, URLs, IPs)
+   * `dedupe` – skip duplicate pairs
+   * `consistency_check_ratio` – run lightweight QA sanity check
+   * `distill_fraction` – generate pseudo-labels for unlabelled data
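+
+ ---
+
+ Jobs run in the background, so `/status` can be polled until completion. A minimal Python polling loop (field names match the `/status` payload defined in `app.py`; `requests` is assumed to be installed on the client):
+
+ ```python
+ import time
+ import requests
+
+ BASE = "https://binkhoale1812-medai-processing.hf.space"
+
+ while True:
+     state = requests.get(f"{BASE}/status", timeout=30).json()
+     print(f"{state.get('progress', 0.0):.0%} - {state.get('message')}")
+     if not state.get("running"):  # finished: message is 'done' or an error
+         break
+     time.sleep(10)
+
+ print(state.get("last_result"))  # artifacts, stats, duration
+ ```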
app.py ADDED
@@ -0,0 +1,423 @@
+ # Root FastAPI
+ import os
+ import json
+ import time, logging
+ import threading
+ import datetime as dt
+ from typing import Optional, Dict
+
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
+ from fastapi.responses import HTMLResponse, JSONResponse
+ from pydantic import BaseModel
+ from dotenv import load_dotenv
+
+ from utils.datasets import resolve_dataset, hf_download_dataset
+ from utils.processor import process_file_into_sft
+ from utils.rag import process_file_into_rag
+ from utils.drive_saver import DriveSaver
+ from utils.llm import Paraphraser
+ from utils.schema import CentralisedWriter
+ from utils.token import get_credentials, exchange_code, build_auth_url
+
+ # ────────── Log ───────────
+ logger = logging.getLogger("app")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     handler = logging.StreamHandler()
+     logger.addHandler(handler)
+
+ # ────────── Boot ──────────
+ load_dotenv(override=True)
+
+ SPACE_NAME = os.getenv("SPACE_NAME", "MedAI Processor")
+ OUTPUT_DIR = os.path.abspath(os.getenv("OUTPUT_DIR", "cache/outputs"))
+ LOG_DIR = os.path.abspath(os.getenv("LOG_DIR", "logs"))
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+ os.makedirs(LOG_DIR, exist_ok=True)
+
+ # --- Bootstrap Google OAuth ---
+ try:
+     creds = get_credentials()
+     if creds:
+         logger.info("✅ OAuth credentials loaded and valid")
+ except Exception as e:
+     logger.warning(f"⚠️ OAuth not initialized yet: {e}")
+
+ # --- Bootstrap Google Drive ---
+ drive = DriveSaver(default_folder_id=os.getenv("GDRIVE_FOLDER_ID"))
+
+ # LLM rotator with paraphraser nodes
+ paraphraser = Paraphraser(
+     nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
+     gemini_model_easy=os.getenv("GEMINI_MODEL_EASY", "gemini-2.5-flash-lite"),
+     gemini_model_hard=os.getenv("GEMINI_MODEL_HARD", "gemini-2.5-flash"),
+ )
+
+ app = FastAPI(title="Medical Dataset Augmenter", version="1.1.0")
+
+ STATE_LOCK = threading.Lock()
+ STATE: Dict[str, object] = {
+     "running": False,
+     "dataset": None,
+     "started_at": None,
+     "progress": 0.0,
+     "message": "idle",
+     "last_result": None
+ }
+
+ class AugmentOptions(BaseModel):
+     # ratios are 0..1
+     paraphrase_ratio: float = 0.0
+     paraphrase_outputs: bool = False
+     backtranslate_ratio: float = 0.0
+     style_standardize: bool = True
+     deidentify: bool = True
+     dedupe: bool = True
+     max_chars: int = 5000                 # cap extremely long contexts
+     consistency_check_ratio: float = 0.0  # small ratio e.g. 0.01
+     # KD / distillation (optional, keeps default off)
+     distill_fraction: float = 0.0         # for unlabeled only
+     expand: bool = True                   # Enable back-translation and complex augmentation
+     max_aug_per_sample: int = 2           # Between 1-3, number of LLM calls to augment/paraphrase data
+
+ class ProcessParams(BaseModel):
+     augment: AugmentOptions = AugmentOptions()
+     sample_limit: Optional[int] = None    # Set data sampling if needed
+     seed: int = 42
+     rag_processing: bool = False          # Enable RAG-specific processing
+
+ def set_state(**kwargs):
+     with STATE_LOCK:
+         STATE.update(kwargs)
+
+ def now_iso():
+     return dt.datetime.utcnow().isoformat()
+
+ # Instructional UI
+ @app.get("/", response_class=HTMLResponse)
+ def root():
+     return f"""
+     <html>
+       <head>
+         <title>{SPACE_NAME} – Medical Dataset Augmenter</title>
+         <style>
+           body {{ font-family: Arial, sans-serif; max-width: 900px; margin: 2rem auto; line-height: 1.5; }}
+           h1, h2 {{ color: #2c3e50; }}
+           button {{
+             background: #2d89ef; color: white; border: none; padding: 8px 16px;
+             border-radius: 5px; cursor: pointer; margin: 5px 0;
+           }}
+           button:hover {{ background: #1b5dab; }}
+           .section {{ margin-bottom: 2rem; }}
+           #log {{ background:#f5f5f5; padding:10px; border-radius:6px; margin-top:10px; font-size:0.9rem; }}
+           a {{ color:#2d89ef; text-decoration:none; }}
+           a:hover {{ text-decoration:underline; }}
+         </style>
+       </head>
+       <body>
+         <h1>📊 {SPACE_NAME} – Medical Dataset Augmenter</h1>
+         <p>This Hugging Face Space processes medical datasets into a <b>centralised fine-tuning format</b>
+         (JSONL + CSV), with optional <i>data augmentation</i>.</p>
+
+         <div class="section">
+           <h2>⚡ Quick Actions</h2>
+           <p>Click a button below to start processing a dataset with default augmentation parameters.</p>
+           <button onclick="startJob('healthcaremagic')">▶ ProcAugment HealthCareMagic (100k)</button><br>
+           <button onclick="startJob('icliniq')">▶ ProcAugment iCliniq (10k-derived)</button><br>
+           <button onclick="startJob('pubmedqa_l')">▶ ProcAugment PubMedQA (Labelled)</button><br>
+           <button onclick="startJob('pubmedqa_u')">▶ ProcAugment PubMedQA (Unlabelled)</button><br>
+           <button onclick="startJob('pubmedqa_map')">▶ ProcAugment PubMedQA (Map)</button><br><br>
+           <div style="border-top: 1px solid #ddd; padding-top: 10px; margin-top: 10px;">
+             <strong>RAG Processing:</strong> Convert to QCA format for RAG systems<br>
+             <button onclick="startRagJob('healthcaremagic')" style="background: #e74c3c;">▶ RAG HealthCareMagic (100k)</button><br>
+             <button onclick="startRagJob('icliniq')" style="background: #e74c3c;">▶ RAG iCliniq (10k-derived)</button><br>
+             <button onclick="startRagJob('pubmedqa_u')" style="background: #e74c3c;">▶ RAG PubMedQA (Unlabelled)</button><br>
+             <button onclick="startRagJob('pubmedqa_l')" style="background: #e74c3c;">▶ RAG PubMedQA (Labelled)</button><br>
+             <button onclick="startRagJob('pubmedqa_map')" style="background: #e74c3c;">▶ RAG PubMedQA (Map)</button>
+           </div>
+         </div>
+
+         <div class="section">
+           <h2>📂 Monitoring</h2>
+           <ul>
+             <li><a href="/status" target="_blank">Check current job status</a></li>
+             <li><a href="/files" target="_blank">List generated artifacts</a></li>
+             <li><a href="https://binkhoale1812-medai-processing.hf.space/oauth2/start" target="_blank">Authorize your GCS credential</a></li>
+             <li><a href="https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing/blob/main/REQUEST.md" target="_blank">📑 Request Doc (all curl examples)</a></li>
+           </ul>
+         </div>
+
+         <div class="section">
+           <h2>📝 Log</h2>
+           <div id="log">Click a button above to run a job...</div>
+         </div>
+
+         <script>
+           async function startJob(dataset) {{
+             const log = document.getElementById("log");
+             const ragToggle = document.getElementById("ragToggle");
+             const isRagMode = ragToggle ? ragToggle.checked : false; // no RAG toggle on this page yet; default to SFT
+
+             log.innerHTML = "⏳ Starting " + (isRagMode ? "RAG " : "") + "job for <b>" + dataset + "</b>...";
+             try {{
+               const resp = await fetch("/process/" + dataset, {{
+                 method: "POST",
+                 headers: {{ "Content-Type": "application/json" }},
+                 body: JSON.stringify({{
+                   augment: {{
+                     paraphrase_ratio: 0.1,
+                     backtranslate_ratio: 0.00, // Increase to 0.05-0.1 for back-translation
+                     paraphrase_outputs: false,
+                     style_standardize: true,
+                     deidentify: true,
+                     dedupe: true,
+                     max_chars: 5000,
+                     expand: true,
+                     max_aug_per_sample: 2
+                   }},
+                   sample_limit: null, // Sample down (currently disabled)
+                   seed: 42,
+                   rag_processing: isRagMode
+                 }})
+               }});
+               const data = await resp.json();
+               if (resp.ok) {{
+                 log.innerHTML = "✅ " + JSON.stringify(data);
+               }} else {{
+                 log.innerHTML = "❌ Error: " + JSON.stringify(data);
+               }}
+             }} catch (err) {{
+               log.innerHTML = "❌ JS Error: " + err;
+             }}
+           }}
+
+           async function startRagJob(dataset) {{
+             const log = document.getElementById("log");
+             log.innerHTML = "⏳ Starting RAG processing for <b>" + dataset + "</b>...";
+             try {{
+               const resp = await fetch("/rag/" + dataset, {{
+                 method: "POST",
+                 headers: {{ "Content-Type": "application/json" }},
+                 body: JSON.stringify({{
+                   sample_limit: null,
+                   seed: 42
+                 }})
+               }});
+               const data = await resp.json();
+               if (resp.ok) {{
+                 log.innerHTML = "✅ RAG Processing Started: " + JSON.stringify(data);
+               }} else {{
+                 log.innerHTML = "❌ Error: " + JSON.stringify(data);
+               }}
+             }} catch (err) {{
+               log.innerHTML = "❌ JS Error: " + err;
+             }}
+           }}
+         </script>
+       </body>
+     </html>
+     """
+
+ @app.get("/status")
+ def status():
+     with STATE_LOCK:
+         return JSONResponse(STATE)
+
+ # ──────── GCS token ────────
+ @app.get("/oauth2/start")
+ def oauth2_start(request: Request):
+     # Compute redirect URI dynamically from the actual host the Space is using
+     host = request.headers.get("x-forwarded-host") or request.headers.get("host")
+     scheme = "https"  # Spaces are HTTPS at the edge
+     redirect_uri = f"{scheme}://{host}/oauth2/callback"
+
+     try:
+         url = build_auth_url(redirect_uri)
+         return JSONResponse({"authorize_url": url})
+     except Exception as e:
+         raise HTTPException(500, f"OAuth init failed: {e}")
+
+ # Display your token
+ @app.get("/oauth2/callback")
+ def oauth2_callback(request: Request, code: str = "", state: str = ""):
+     if not code:
+         raise HTTPException(400, "Missing 'code'")
+     # Send req
+     host = request.headers.get("x-forwarded-host") or request.headers.get("host")
+     scheme = "https"
+     redirect_uri = f"{scheme}://{host}/oauth2/callback"
+     # Parse and show token code
+     try:
+         creds = exchange_code(code, redirect_uri)
+         refresh = creds.refresh_token or os.getenv("GDRIVE_REFRESH_TOKEN", "")
+         # UI
+         html = f"""
+         <html>
+           <head>
+             <style>
+               body {{ font-family: sans-serif; margin: 2em; }}
+               .token-box {{
+                 padding: 1em; border: 1px solid #ccc; border-radius: 6px;
+                 background: #f9f9f9; font-family: monospace;
+                 word-break: break-all; white-space: pre-wrap;
+               }}
+               .note {{ margin-top: 1em; color: #555; }}
+             </style>
+           </head>
+           <body>
+             <h2>✅ Google Drive Authorized</h2>
+             <p>Your refresh token is:</p>
+             <div class="token-box">{refresh}</div>
+             <p class="note">
+               👉 Copy this token and save it into your Hugging Face Space Secrets
+               as <code>GDRIVE_REFRESH_TOKEN</code>.
+               This ensures persistence across rebuilds.
+             </p>
+           </body>
+         </html>
+         """
+         return HTMLResponse(html)
+     except Exception as e:
+         raise HTTPException(500, f"OAuth exchange failed: {e}")
+
+ @app.get("/files")
+ def files():
+     out = []
+     for root, _, fns in os.walk(OUTPUT_DIR):
+         for fn in fns:
+             out.append(os.path.relpath(os.path.join(root, fn), OUTPUT_DIR))
+     return {"output_dir": OUTPUT_DIR, "files": sorted(out)}
+
+ @app.post("/process/{dataset_key}")
+ def process_dataset(dataset_key: str, params: ProcessParams, background: BackgroundTasks):
+     with STATE_LOCK:
+         if STATE["running"]:
+             logger.warning(
+                 f"[JOB] Rejecting new job dataset={dataset_key} "
+                 f"current={STATE['dataset']} started_at={STATE['started_at']}"
+             )
+             raise HTTPException(409, detail="Another job is running.")
+         STATE["running"] = True
+         STATE["dataset"] = dataset_key
+         STATE["started_at"] = now_iso()
+         STATE["progress"] = 0.0
+         STATE["message"] = "starting"
+         STATE["last_result"] = None
+     logger.info(
+         f"[JOB] Queued dataset={dataset_key} "
+         f"params={{'sample_limit': {params.sample_limit}, 'seed': {params.seed}, "
+         f"'rag_processing': {params.rag_processing}, 'augment': {params.augment.dict()} }}"
+     )
+     # Start job to background runner thread
+     logger.info(f"[JOB] Started dataset={dataset_key}")
+     background.add_task(_run_job, dataset_key, params)
+     return {"ok": True, "message": f"Job for '{dataset_key}' started."}
+
+ @app.post("/rag/{dataset_key}")
+ def process_rag_dataset(dataset_key: str, params: ProcessParams, background: BackgroundTasks):
+     """Dedicated RAG processing endpoint"""
+     # Force RAG processing mode
+     params.rag_processing = True
+
+     with STATE_LOCK:
+         if STATE["running"]:
+             logger.warning(
+                 f"[RAG] Rejecting new RAG job dataset={dataset_key} "
+                 f"current={STATE['dataset']} started_at={STATE['started_at']}"
+             )
+             raise HTTPException(409, detail="Another job is running.")
+         STATE["running"] = True
+         STATE["dataset"] = dataset_key
+         STATE["started_at"] = now_iso()
+         STATE["progress"] = 0.0
+         STATE["message"] = "starting RAG processing"
+         STATE["last_result"] = None
+     logger.info(
+         f"[RAG] Queued RAG dataset={dataset_key} "
+         f"params={{'sample_limit': {params.sample_limit}, 'seed': {params.seed} }}"
+     )
+     # Start job to background runner thread
+     logger.info(f"[RAG] Started RAG dataset={dataset_key}")
+     background.add_task(_run_job, dataset_key, params)
+     return {"ok": True, "message": f"RAG processing job for '{dataset_key}' started."}
+
+ def _run_job(dataset_key: str, params: ProcessParams):
+     t0 = time.time()
+     try:
+         ds = resolve_dataset(dataset_key)
+         if not ds:
+             set_state(running=False, message="unknown dataset")
+             return
+
+         # Download HF Dataset and start processing units
+         set_state(message="downloading")
+         local_path = hf_download_dataset(ds["repo_id"], ds["filename"], ds["repo_type"])
+         logger.info(f"[JOB] Downloaded {ds['repo_id']}/{ds['filename']} → {local_path}")
+
+         # Prepare timestamp for file writing
+         ts = dt.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
+         mode_suffix = "rag" if params.rag_processing else "sft"
+         stem = f"{dataset_key}-{mode_suffix}-{ts}"
+         jsonl_path = os.path.join(OUTPUT_DIR, f"{stem}.jsonl")
+         csv_path = os.path.join(OUTPUT_DIR, f"{stem}.csv")
+         # Change state
+         set_state(message="processing", progress=0.05)
+
+         # Writer
+         writer = CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
+
+         if params.rag_processing:
+             # RAG processing mode
+             set_state(message="RAG processing", progress=0.1)
+             count, stats = process_file_into_rag(
+                 dataset_key=dataset_key,
+                 input_path=local_path,
+                 writer=writer,
+                 nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
+                 sample_limit=params.sample_limit,
+                 seed=params.seed,
+                 progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"])
+             )
+         else:
+             # Standard SFT processing mode
+             set_state(message="SFT processing", progress=0.1)
+             count, stats = process_file_into_sft(
+                 dataset_key=dataset_key,
+                 input_path=local_path,
+                 writer=writer,
+                 paraphraser=paraphraser,
+                 augment_opts=params.augment.dict(),
+                 sample_limit=params.sample_limit,
+                 seed=params.seed,
+                 progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"])
+             )
+         logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
+         writer.close()
+
+         # Upload to GDrive
+         set_state(message="uploading to Google Drive", progress=0.95)
+         up1 = drive.upload_file_to_drive(jsonl_path, mimetype="application/json")
+         up2 = drive.upload_file_to_drive(csv_path, mimetype="text/csv")
+         logger.info(
+             f"[JOB] Uploads complete uploaded={bool(up1 and up2)} "
+             f"jsonl={jsonl_path} csv={csv_path}"
+         )
+
+         # Finalize the task
+         result = {
+             "dataset": dataset_key,
+             "processing_mode": "RAG" if params.rag_processing else "SFT",
+             "processed_rows": count,
+             "stats": stats,
+             "artifacts": {"jsonl": jsonl_path, "csv": csv_path},
+             "uploaded": bool(up1 and up2),
+             "duration_sec": round(time.time() - t0, 2)
+         }
+         set_state(message="done", progress=1.0, last_result=result, running=False)
+         logger.info(
+             f"[JOB] Finished dataset={dataset_key} "
+             f"duration_sec={round(time.time()-t0, 2)}"
+         )
+     except Exception as e:
+         logger.exception(f"[JOB] Error for dataset={dataset_key}: {e}")
+         set_state(message=f"error: {e}", running=False)
mount_drive.py ADDED
@@ -0,0 +1,9 @@
+ # Check Google Drive status
+ from utils.drive_saver import DriveSaver
+
+ if __name__ == "__main__":
+     ds = DriveSaver()
+     if ds.is_service_available():
+         print("Drive ready.")
+     else:
+         print("Drive NOT ready.")
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ fastapi
+ uvicorn[standard]
+ python-dotenv
+ huggingface_hub
+ requests
+ google-genai
+ google-api-python-client
+ google-auth
+ google-auth-httplib2
+ google-auth-oauthlib
+ orjson
+ ftfy
+ langid
utils/__init__.py ADDED
@@ -0,0 +1,22 @@
+ """
+ Utility package for the Medical Dataset Augmenter Space.
+
+ This package provides:
+ - drive_saver: Google Drive upload helper
+ - llm: API key rotation, paraphraser, translation/backtranslation
+ - datasets: Hugging Face dataset resolver & downloader
+ - processor: dataset-specific processing pipeline with augmentation
+ - schema: centralised SFT writer (JSONL + CSV)
+ - token: GCS project token refresher and authenticator
+ - augment: low-level augmentation utilities (text cleanup, deid, paraphrase hooks)
+ """
+
+ from . import drive_saver
+ from . import llm
+ from . import datasets
+ from . import processor
+ from . import schema
+ from . import augment
+ from . import token
+
+ __all__ = ["drive_saver", "llm", "datasets", "processor", "schema", "augment", "token"]
utils/.DS_Store ADDED
Binary file (6.15 kB).
utils/augment.py ADDED
@@ -0,0 +1,105 @@
+ # augmentation utility agent
+ import re
+ import random
+ from typing import Dict, Tuple
+ import ftfy
+ import langid
+
+ P_EMAIL = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
+ P_PHONE = re.compile(r"(?:(?:\+?\d{1,3})?[\s-]?)?(?:\(?\d{2,4}\)?[\s-]?)?\d{3,4}[\s-]?\d{3,4}")
+ P_URL = re.compile(r"https?://\S+|www\.\S+")
+ P_IP = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
+
+ def fix_unicode(s: str) -> str:
+     return ftfy.fix_text(s or "")
+
+ def normalize_whitespace(s: str) -> str:
+     s = s.replace("\u00A0", " ")
+     s = re.sub(r"[ \t]+", " ", s)
+     s = re.sub(r"\s+\n", "\n", s)
+     s = re.sub(r"\n{3,}", "\n\n", s)
+     return s.strip()
+
+ def canonicalize_quotes(s: str) -> str:
+     return s.replace("“", '"').replace("”", '"').replace("’", "'").replace("‘", "'")
+
+ def ensure_terminal_punct(s: str) -> str:
+     if not s: return s
+     if s[-1] in ".!?": return s
+     return s + "."
+
+ def deidentify(s: str) -> str:
+     s = P_EMAIL.sub("[REDACTED_EMAIL]", s)
+     s = P_PHONE.sub("[REDACTED_PHONE]", s)
+     s = P_URL.sub("[REDACTED_URL]", s)
+     s = P_IP.sub("[REDACTED_IP]", s)
+     return s
+
+ def lang_is_english(s: str) -> bool:
+     try:
+         lang, _ = langid.classify((s or "")[:2000])
+         return lang == "en"
+     except Exception:
+         return True
+
+ def length_cap(s: str, max_chars: int) -> str:
+     if len(s) <= max_chars:
+         return s
+     # try to cut at sentence boundary
+     cut = s[:max_chars]
+     last_dot = cut.rfind(". ")
+     if last_dot > 300:  # don't cut too aggressively
+         return cut[:last_dot+1] + " …"
+     return cut + " …"
+
+ def fingerprint(instr: str, user: str, out: str) -> str:
+     # Simple, fast fingerprint for dedupe
+     def norm(x: str) -> str:
+         x = x.lower()
+         x = re.sub(r"[^a-z0-9]+", " ", x)
+         x = re.sub(r"\s+", " ", x).strip()
+         return x
+     core = "||".join([norm(instr), norm(user), norm(out)])
+     # lightweight hash
+     import hashlib
+     return hashlib.md5(core.encode("utf-8")).hexdigest()
+
+ def style_standardize_answer(ans: str) -> str:
+     if not ans: return ans
+     ans = ans.strip()
+     # Gentle guardrails, neutral voice
+     prefix = ""
+     # Avoid absolute guarantees
+     ans = re.sub(r"\b(guarantee|100%|certainly|always|never)\b", "likely", ans, flags=re.I)
+     # Remove sign-offs typical of forums
+     ans = re.sub(r"\n*(thanks|thank you|regards|cheers)[^\n]*$", "", ans, flags=re.I)
+     return ensure_terminal_punct(ans)
+
+ def base_cleanup(s: str, max_chars: int, do_deid: bool) -> str:
+     s = fix_unicode(s)
+     s = canonicalize_quotes(s)
+     s = normalize_whitespace(s)
+     if do_deid:
+         s = deidentify(s)
+     s = length_cap(s, max_chars)
+     return s
+
+ def maybe_paraphrase(text: str, ratio: float, paraphraser, difficulty: str) -> Tuple[str, bool]:
+     if ratio <= 0 or not text: return text, False
+     if random.random() < ratio:
+         return paraphraser.paraphrase(text, difficulty=difficulty), True
+     return text, False
+
+ def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool]:
+     if ratio <= 0 or not text: return text, False
+     if random.random() < ratio:
+         bt = paraphraser.backtranslate(text, via_lang="de")
+         return bt if bt else text, bool(bt)
+     return text, False
+
+ def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
+     if ratio <= 0 or (not user) or (not out):
+         return True
+     if random.random() >= ratio:
+         return True
+     return paraphraser.consistency_check(user, out)
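+
+ if __name__ == "__main__":
+     # Illustrative smoke test added for this write-up, not part of the
+     # module's original logic; it exercises the cleanup + dedupe helpers
+     # on a synthetic example with a hypothetical address.
+     raw = "Contact  me at jane@example.com   for a follow-up"
+     cleaned = base_cleanup(raw, max_chars=100, do_deid=True)
+     print(cleaned)                                   # email becomes [REDACTED_EMAIL]
+     print(fingerprint("instr", cleaned, "answer"))   # stable MD5 key for dedupe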
utils/datasets.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # HF dataset download resolver + downloader
+ import os
+ from typing import Optional
+ from huggingface_hub import hf_hub_download
+ import logging
+
+ # Logger
+ logger = logging.getLogger("datasets")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     logger.addHandler(logging.StreamHandler())
+
+
+ DATASETS = {
+     "healthcaremagic": {
+         "repo_id": "BinKhoaLe1812/MedDialog-EN-100k",
+         "filename": "HealthCareMagic-100k.json",
+         "repo_type": "dataset"
+     },
+     "icliniq": {
+         "repo_id": "BinKhoaLe1812/MedDialog-EN-10k",
+         "filename": "iCliniq.json",
+         "repo_type": "dataset"
+     },
+     "pubmedqa_l": {
+         "repo_id": "BinKhoaLe1812/PubMedQA-L",
+         "filename": "ori_pqal.json",
+         "repo_type": "dataset"
+     },
+     "pubmedqa_u": {
+         "repo_id": "BinKhoaLe1812/PubMedQA-U",
+         "filename": "ori_pqau.json",
+         "repo_type": "dataset"
+     },
+     "pubmedqa_map": {
+         "repo_id": "BinKhoaLe1812/PubMedQA-Map",
+         "filename": "pubmed_qa_map.json",
+         "repo_type": "dataset"
+     }
+ }
+
+
+ def resolve_dataset(key: str) -> Optional[dict]:
+     return DATASETS.get(key.lower())
+
+
+ def hf_download_dataset(repo_id: str, filename: str, repo_type: str = "dataset") -> str:
+     token = os.getenv("HF_TOKEN")
+     logger.info(
+         f"[HF] Download {repo_id}/{filename} (type={repo_type}) token={'yes' if token else 'no'}"
+     )
+     path = hf_hub_download(
+         repo_id=repo_id,
+         filename=filename,
+         repo_type=repo_type,
+         token=token,
+         local_dir=os.path.abspath("cache/hf"),
+         local_dir_use_symlinks=False
+     )
+     try:
+         size = os.path.getsize(path)
+         logger.info(f"[HF] Downloaded to {path} size={size} bytes")
+     except Exception:
+         logger.info(f"[HF] Downloaded to {path}")
+     return path
+
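
Resolution and download compose in two calls. A minimal usage sketch, assuming HF_TOKEN is set in the environment when the repos require auth:

    from utils.datasets import resolve_dataset, hf_download_dataset

    spec = resolve_dataset("icliniq")             # dict with repo_id / filename / repo_type, or None
    if spec is not None:
        local_path = hf_download_dataset(**spec)  # cached under cache/hf
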
utils/drive_saver.py ADDED
@@ -0,0 +1,88 @@
+ # Save final post-process to Google Drive
+ import os, json, logging
+ from typing import Optional
+ from google.oauth2 import service_account
+ from googleapiclient.discovery import build
+ from googleapiclient.http import MediaFileUpload
+
+ from utils.token import get_credentials
+
+ logger = logging.getLogger("dsaver")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
+     handler = logging.StreamHandler()
+     handler.setFormatter(fmt)
+     logger.addHandler(handler)
+
+ class DriveSaver:
+     """Google Drive uploader. Prefers OAuth; optional SA fallback (Shared Drive only)."""
+
+     def __init__(self, default_folder_id: Optional[str] = None):
+         self.service = None
+         self.folder_id = default_folder_id or os.getenv("GDRIVE_FOLDER_ID")
+         self.supports_all_drives = os.getenv("GDRIVE_FOLDER_IS_SHARED", "false").lower() in ("1", "true", "yes")
+         self.allow_sa_fallback = os.getenv("GDRIVE_ALLOW_SA_FALLBACK", "false").lower() in ("1", "true", "yes")
+         if not self.folder_id:
+             logger.warning("📁 No GDRIVE_FOLDER_ID set; uploads must provide folder_id explicitly")
+         self._initialize_service()
+
+     def _initialize_service(self):
+         creds = get_credentials()
+         if creds:
+             logger.info("✅ Using OAuth credentials")
+         else:
+             # Optional SA fallback — ONLY valid for Shared Drives where the SA is a member
+             if self.allow_sa_fallback:
+                 creds_env = os.getenv("GDRIVE_CREDENTIALS_JSON")
+                 if creds_env:
+                     try:
+                         info = json.loads(creds_env)
+                         if info.get("type") == "service_account":
+                             creds = service_account.Credentials.from_service_account_info(
+                                 info, scopes=["https://www.googleapis.com/auth/drive"]
+                             )
+                             logger.info("✅ Using Service Account credentials (fallback)")
+                             if not self.supports_all_drives:
+                                 logger.warning("⚠️ SA fallback without Shared Drive mode will likely fail (no quota). "
+                                                "Set GDRIVE_FOLDER_IS_SHARED=true and use a Shared Drive folder ID.")
+                         else:
+                             logger.error("❌ GDRIVE_CREDENTIALS_JSON is not a service account JSON")
+                     except Exception as e:
+                         logger.error(f"❌ Failed to init Service Account: {e}")
+         if not creds:
+             logger.error("❌ No valid Google credentials available (OAuth or SA).")
+             self.service = None
+             return
+         # Build Drive service
+         self.service = build("drive", "v3", credentials=creds)
+         logger.info("✅ Google Drive service initialized")
+
+     def upload_file_to_drive(self, file_path: str, folder_id: Optional[str] = None, mimetype: Optional[str] = None) -> bool:
+         if not self.service:
+             logger.error("❌ Drive service not initialized")
+             return False
+         try:
+             target_folder = folder_id or self.folder_id
+             name = os.path.basename(file_path)
+             media = MediaFileUpload(file_path, mimetype=mimetype or "application/octet-stream")
+             metadata = {"name": name, "parents": [target_folder]}
+             req = self.service.files().create(
+                 body=metadata,
+                 media_body=media,
+                 fields="id",
+                 supportsAllDrives=self.supports_all_drives
+             )
+             req.execute()
+             logger.info(f"✅ Uploaded '{name}' to Drive (folder: {target_folder})")
+             return True
+         except Exception as e:
+             logger.error(f"❌ Drive upload failed: {e}")
+             return False
+
+     def is_service_available(self) -> bool:
+         return self.service is not None
+
+     def set_folder_id(self, folder_id: str):
+         self.folder_id = folder_id
+         logger.info(f"📁 Default folder ID updated: {folder_id}")
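
A minimal upload sketch. The output path and mimetype here are illustrative; GDRIVE_FOLDER_ID (and, for Shared Drives, GDRIVE_FOLDER_IS_SHARED) are read from the environment as above:

    from utils.drive_saver import DriveSaver

    saver = DriveSaver()  # picks up GDRIVE_FOLDER_ID and related env vars
    if saver.is_service_available():
        # "processed/central_sft.jsonl" is a placeholder output path
        saver.upload_file_to_drive("processed/central_sft.jsonl", mimetype="application/json")
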
utils/llm.py ADDED
@@ -0,0 +1,186 @@
+ # Round-robin rotator + paraphrasing + translation/backtranslation
+ import os
+ import re
+ import logging
+ import requests
+ from typing import Optional
+ from google import genai
+
+ logger = logging.getLogger("llm")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     handler = logging.StreamHandler()
+     logger.addHandler(handler)
+
+ # Trim LLM output to a short snippet for logging
+ def snip(s: str, n: int = 12) -> str:
+     if not isinstance(s, str): return "∅"
+     parts = s.strip().split()
+     return " ".join(parts[:n]) + (" …" if len(parts) > n else "")
+
+ class KeyRotator:
+     def __init__(self, env_prefix: str, max_keys: int = 5):
+         keys = []
+         for i in range(1, max_keys + 1):
+             v = os.getenv(f"{env_prefix}_{i}")
+             if v:
+                 keys.append(v.strip())
+         if not keys:
+             logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
+         self.keys = keys
+         self.dead = set()
+         self.idx = 0
+
+     def next_key(self) -> Optional[str]:
+         if not self.keys:
+             return None
+         for _ in range(len(self.keys)):
+             k = self.keys[self.idx % len(self.keys)]
+             self.idx += 1
+             if k not in self.dead:
+                 return k
+         return None
+
+     def mark_bad(self, key: Optional[str]):
+         if key:
+             self.dead.add(key)
+             logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")
+
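
The rotator walks the configured keys round-robin and permanently skips quarantined ones, so a single failing key drops out without stalling the pipeline. A small sketch with dummy key values (illustrative only):

    import os
    from utils.llm import KeyRotator

    os.environ["NVIDIA_API_1"] = "key-a"   # dummy values for the sketch
    os.environ["NVIDIA_API_2"] = "key-b"

    rot = KeyRotator("NVIDIA_API")
    k1 = rot.next_key()   # "key-a"
    rot.mark_bad(k1)      # quarantine it after an auth/quota error
    k2 = rot.next_key()   # "key-b"; "key-a" is skipped from now on
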
+ class GeminiClient:
+     def __init__(self, rotator: KeyRotator, default_model: str):
+         self.rotator = rotator
+         self.default_model = default_model
+
+     def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_output_tokens: int = 512) -> Optional[str]:
+         key = self.rotator.next_key()
+         if not key:
+             return None
+         try:
+             client = genai.Client(api_key=key)
+             # NOTE: matches your required pattern/use
+             res = client.models.generate_content(
+                 model=model or self.default_model,
+                 contents=prompt
+             )
+             text = getattr(res, "text", None)
+             if text:
+                 logger.info(f"[LLM][Gemini] out={snip(text)}")
+                 return text
+         except Exception as e:
+             logger.error(f"[LLM][Gemini] {e}")
+             self.rotator.mark_bad(key)
+         return None
+
+ class NvidiaClient:
+     def __init__(self, rotator: KeyRotator, default_model: str):
+         self.rotator = rotator
+         self.default_model = default_model
+         self.url = os.getenv("NVIDIA_API_URL", "https://integrate.api.nvidia.com/v1/chat/completions")
+
+     # Regex-based cleanup of boilerplate prefixes in responses
+     def _clean_resp(self, resp: str) -> str:
+         if not resp: return resp
+         txt = resp.strip()
+         # Remove common boilerplate prefixes
+         for pat in [
+             r"^Here is (a|the) .*?:\s*",
+             r"^Paraphrased(?: version)?:\s*",
+             r"^Sure[,.]?\s*",
+             r"^Okay[,.]?\s*"
+         ]:
+             txt = re.sub(pat, "", txt, flags=re.I)
+         return txt.strip()
+
+     def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_tokens: int = 512) -> Optional[str]:
+         key = self.rotator.next_key()
+         if not key:
+             return None
+         try:
+             headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
+             payload = {
+                 "model": model or self.default_model,
+                 "messages": [{"role": "user", "content": prompt}],
+                 "temperature": temperature,
+                 "max_tokens": max_tokens
+             }
+             r = requests.post(self.url, headers=headers, json=payload, timeout=45)
+             if r.status_code >= 400:
+                 raise RuntimeError(f"HTTP {r.status_code}: {r.text[:200]}")
+             data = r.json()
+             text = data["choices"][0]["message"]["content"]
+             clean = self._clean_resp(text)
+             logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
+             return clean
+         except Exception as e:
+             logger.error(f"[LLM][NVIDIA] {e}")
+             self.rotator.mark_bad(key)
+         return None
+
+ class Paraphraser:
+     """Prefers NVIDIA (cheap), falls back to Gemini. Also offers translate/backtranslate and a tiny consistency judge."""
+     def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
+         self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
+         self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
+         self.gm_hard = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_hard)
+
+     # Regex-based cleanup of boilerplate prefixes in responses
+     def _clean_resp(self, resp: str) -> str:
+         if not resp: return resp
+         txt = resp.strip()
+         # Remove common boilerplate prefixes
+         for pat in [
+             r"^Here is (a|the) .*?:\s*",
+             r"^Paraphrased(?: version)?:\s*",
+             r"^Sure[,.]?\s*",
+             r"^Okay[,.]?\s*"
+         ]:
+             txt = re.sub(pat, "", txt, flags=re.I)
+         return txt.strip()
+
+     # ————— Paraphrase —————
+     def paraphrase(self, text: str, difficulty: str = "easy") -> str:
+         if not text or len(text) < 12:
+             return text
+         prompt = (
+             "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
+             "Do not fabricate or remove factual claims.\n"
+             "Return ONLY the rewritten text, without any introduction or commentary.\n" + text
+         )
+         out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
+         if out: return self._clean_resp(out)
+         gm = self.gm_easy if difficulty == "easy" else self.gm_hard
+         out = gm.generate(prompt, max_output_tokens=min(600, max(128, len(text)//2)))
+         return self._clean_resp(out) if out else text
+
+     # ————— Translate & Backtranslate —————
+     def translate(self, text: str, target_lang: str = "de") -> Optional[str]:
+         if not text: return text
+         prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
+         out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
+         if out: return out.strip()
+         return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
+
+     def backtranslate(self, text: str, via_lang: str = "de") -> Optional[str]:
+         if not text: return text
+         mid = self.translate(text, target_lang=via_lang)
+         if not mid: return None
+         prompt = f"Translate the following {via_lang} text back to English, preserving the exact meaning:\n\n{mid}"
+         out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
+         if out: return out.strip()
+         res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
+         return res.strip() if res else None
+
+     # ————— Consistency Judge (cheap, ratio-based) —————
+     def consistency_check(self, user: str, output: str) -> bool:
+         """Return True if 'output' appears supported by 'user' (context/question). Soft heuristic via LLM."""
+         prompt = (
+             "You are a strict medical QA validator. Given the USER input (question+context) "
+             "and the MODEL ANSWER, reply with exactly 'PASS' if the answer is supported and safe, "
+             "otherwise 'FAIL'. No extra text.\n\n"
+             f"USER:\n{user}\n\nANSWER:\n{output}"
+         )
+         out = self.nv.generate(prompt, temperature=0.0, max_tokens=3)
+         if not out:
+             out = self.gm_easy.generate(prompt, max_output_tokens=3)
+         return isinstance(out, str) and "PASS" in out.upper()
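
End-to-end, the Paraphraser tries the NVIDIA endpoint first and only falls back to Gemini when that returns nothing. A usage sketch; the model names are illustrative placeholders, and real NVIDIA_API_* / GEMINI_API_* keys must be present for the calls to succeed:

    from utils.llm import Paraphraser

    p = Paraphraser(
        nvidia_model="meta/llama-3.1-8b-instruct",   # illustrative
        gemini_model_easy="gemini-2.0-flash",        # illustrative
        gemini_model_hard="gemini-2.0-pro",          # illustrative
    )
    better = p.paraphrase("Patient reports a dull occipital headache for 2 days.", difficulty="easy")
    round_trip = p.backtranslate("Take ibuprofen with food.", via_lang="de")
    ok = p.consistency_check(user="Question: Is rest advised?", output=better)
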
utils/processor.py ADDED
@@ -0,0 +1,411 @@
+ # Dataset-specific parsers + paraphrasing flow
+ import json
+ import random
+ import hashlib
+ import logging
+ from typing import Callable, Optional, Dict, Tuple
+
+ from utils.schema import sft_row
+ from utils import augment as A
+
+ # Logger
+ logger = logging.getLogger("processor")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     logger.addHandler(logging.StreamHandler())
+
+
+ def _hash_id(*parts) -> str:
+     h = hashlib.sha256()
+     for p in parts:
+         h.update(str(p).encode("utf-8"))
+     return h.hexdigest()[:16]
+
+ def _iter_json_or_jsonl(path: str):
+     with open(path, "r", encoding="utf-8") as f:
+         first = f.read(1); f.seek(0)
+         if first == "[":
+             data = json.load(f)
+             for obj in data: yield obj
+         else:
+             for line in f:
+                 line = line.strip()
+                 if line: yield json.loads(line)
+
+ def process_file_into_sft(
+     dataset_key: str,
+     input_path: str,
+     writer,
+     paraphraser,
+     augment_opts: Dict,
+     sample_limit: Optional[int],
+     seed: int,
+     progress_cb: Optional[Callable[[float, str], None]]
+ ) -> Tuple[int, Dict]:
+     random.seed(seed)
+     stats = {
+         "written": 0,
+         "paraphrased_input": 0,
+         "paraphrased_output": 0,
+         "backtranslated_input": 0,
+         "backtranslated_output": 0,
+         "dedup_skipped": 0,
+         "consistency_failed": 0
+     }
+     # Start processing SFT
+     key_summary = {k: augment_opts.get(k) for k in (
+         "paraphrase_ratio", "backtranslate_ratio", "paraphrase_outputs",
+         "style_standardize", "deidentify", "dedupe",
+         "consistency_check_ratio", "distill_fraction"
+     )}
+     logger.info(
+         f"[PROC] Begin dataset={dataset_key} sample_limit={sample_limit} opts={key_summary}"
+     )
+     # Shared fingerprint set if deduplication is enabled
+     dedupe_seen = set() if augment_opts.get("dedupe", True) else None
+
+     key = dataset_key.lower()
+     if key in ("healthcaremagic", "icliniq"):
+         count = _proc_med_dialog(source=key, path=input_path, writer=writer,
+                                  paraphraser=paraphraser, opts=augment_opts,
+                                  sample_limit=sample_limit, stats=stats, cb=progress_cb, dedupe_seen=dedupe_seen)
+     elif key == "pubmedqa_l":
+         count = _proc_pubmedqa_l(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
+     elif key == "pubmedqa_u":
+         count = _proc_pubmedqa_u(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
+     elif key == "pubmedqa_map":
+         count = _proc_pubmedqa_map(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
+     else:
+         raise ValueError(f"Unknown dataset: {dataset_key}")
+     logger.info(f"[PROC] End dataset={dataset_key} stats={stats}")
+     return count, stats
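
A minimal driver sketch for this entry point, wiring in the writer from utils.schema and the paraphraser from utils.llm (file paths, model names, and option values below are illustrative):

    from utils.processor import process_file_into_sft
    from utils.schema import CentralisedWriter
    from utils.llm import Paraphraser

    writer = CentralisedWriter("central.jsonl", "central.csv")
    paraphraser = Paraphraser("meta/llama-3.1-8b-instruct",
                              "gemini-2.0-flash", "gemini-2.0-pro")  # illustrative models
    opts = {"paraphrase_ratio": 0.2, "backtranslate_ratio": 0.1,
            "dedupe": True, "expand": True, "max_aug_per_sample": 1}
    count, stats = process_file_into_sft(
        "icliniq", "cache/hf/iCliniq.json", writer, paraphraser,
        augment_opts=opts, sample_limit=1000, seed=42, progress_cb=None
    )
    writer.close()
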
+
+ # ——————————— helpers ———————————
+ def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
+     """Return a list of (user_variant, out_variant, applied_tags) not including the original."""
+     variants = []
+     max_k = max(0, int(opts.get("max_aug_per_sample", 1)))
+     for _ in range(max_k):
+         applied = []
+         u2, did_p = A.maybe_paraphrase(user, opts.get("paraphrase_ratio", 0.0), paraphraser, "easy")
+         if did_p: applied.append("paraphrase_input"); stats["paraphrased_input"] += 1
+         u3, did_bt = A.maybe_backtranslate(u2, opts.get("backtranslate_ratio", 0.0), paraphraser)
+         if did_bt: applied.append("backtranslate_input"); stats["backtranslated_input"] += 1
+
+         o3 = out
+         if opts.get("paraphrase_outputs", False):
+             o2, did_p2 = A.maybe_paraphrase(out, opts.get("paraphrase_ratio", 0.0), paraphraser, "hard")
+             if did_p2: applied.append("paraphrase_output"); stats["paraphrased_output"] += 1
+             o3b, did_bt2 = A.maybe_backtranslate(o2, opts.get("backtranslate_ratio", 0.0), paraphraser)
+             if did_bt2: applied.append("backtranslate_output"); stats["backtranslated_output"] += 1
+             o3 = o3b
+
+         # If nothing applied, skip this variant
+         if not applied:
+             continue
+         # Style standardize and punctuation for the variant too
+         if opts.get("style_standardize", True):
+             o3 = A.style_standardize_answer(o3)
+         u3 = A.ensure_terminal_punct(u3) if u3 else u3
+         o3 = A.ensure_terminal_punct(o3) if o3 else o3
+         variants.append((u3, o3, applied))
+     return variants
+
+ def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
+     # Base cleanup & caps (returns cleaned strings)
+     user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
+     out = A.base_cleanup(out, opts.get("max_chars", 5000), opts.get("deidentify", True))
+     instr = A.base_cleanup(instr, opts.get("max_chars", 5000), False)
+
+     # Language sanity (mostly English; skip aggressive transforms if not)
+     if not A.lang_is_english(user):  # very rare
+         return instr, user, out, []
+
+     # Track which augmentations and stylings have been applied
+     applied = []
+
+     # Style-standardize the answer
+     if opts.get("style_standardize", True):
+         out = A.style_standardize_answer(out)
+         applied.append("style_standardize")
+
+     # Ensure punctuation/whitespace
+     user = A.ensure_terminal_punct(user) if user else user
+     out = A.ensure_terminal_punct(out) if out else out
+
+     return instr, user, out, applied
+
+ def _commit_row(writer, source, rid, task, instr, user, out, opts, stats, aug_applied, extra_meta=None, dedupe_seen=None):
+     # Skip duplicates by content fingerprint
+     if dedupe_seen is not None:
+         fp = A.fingerprint(instr, user, out)
+         if fp in dedupe_seen:
+             stats["dedup_skipped"] += 1
+             return False
+         dedupe_seen.add(fp)
+
+     meta = {"augmentations": aug_applied}
+     if extra_meta:
+         meta.update(extra_meta)
+
+     row = sft_row(instr, user, out, source=source, rid=rid, task=task, meta=meta)
+     writer.write(row)
+     stats["written"] += 1
+     return True
+
+ # ——————————— dataset processors ———————————
+
+ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
+     count = 0
+     for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
+         try:
+             instr_raw = obj.get("instruction") or "Answer the patient's question like a clinician. Be concise and safe."
+             user_raw = obj.get("input") or ""
+             out_raw = obj.get("output") or ""
+
+             # Ensure we have string values
+             instr = str(instr_raw).strip()
+             user = str(user_raw).strip()
+             out = str(out_raw).strip()
+             rid = _hash_id(source, i, len(user), len(out))
+         except Exception as e:
+             logger.warning(f"[PROC] {source} error processing item {i}: {e}, item: {obj}")
+             continue
+
+         try:
+             instr, user, out, applied = _apply_aug(instr, user, out, source, opts, paraphraser, stats)
+
+             # Optional consistency spot-check (cheap)
+             if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
+                 stats["consistency_failed"] += 1
+                 # keep the sample but tag it
+                 applied.append("consistency_flag")
+
+             # 1) ALWAYS write the original (cleaned/style-standardised only)
+             _commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen)
+             # 2) If expansion is enabled, add augmented copies
+             if opts.get("expand", True):
+                 for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
+                     rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
+                     _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
+
+             # Increment count only on success
+             count += 1
+         except Exception as e:
+             logger.warning(f"[PROC] {source} error in processing/augmentation for item {i}: {e}")
+             continue
+         if sample_limit and count >= sample_limit:
+             break
+         if cb and i % 1000 == 0:
+             cb(min(0.9, 0.05 + i/200000), f"{source}: processed {i} rows")
+     if cb:
+         cb(0.92, f"{source} done ({count})")
+     logger.info(f"[PROC] {source} done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
+     return count
+
+ def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
+     with open(path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+     count = 0
+     for k, v in data.items():
+         try:
+             q_raw = v.get("QUESTION") or ""
+             ctx_list = v.get("CONTEXTS") or []
+             long_ans_raw = v.get("LONG_ANSWER") or ""
+             final_raw = v.get("final_decision") or ""
+
+             # Ensure we have string values
+             q = str(q_raw).strip() if q_raw else ""
+             if isinstance(ctx_list, list):
+                 context = "\n".join(str(ctx) for ctx in ctx_list).strip()
+             else:
+                 context = str(ctx_list).strip()
+             long_ans = str(long_ans_raw).strip() if long_ans_raw else ""
+             final = str(final_raw).strip() if final_raw else ""
+         except Exception as e:
+             logger.warning(f"[PROC] pubmedqa_l error processing item {k}: {e}, item: {v}")
+             continue
+
+         try:
+             instr = "Answer the biomedical question using the provided context. Include a concise rationale if possible."
+             user = f"Question: {q}\n\nContext:\n{context}" if context else f"Question: {q}"
+             out = long_ans if long_ans else final
+             rid = str(k)
+
+             instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
+             _commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
+                         extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen)
+             if opts.get("expand", True):
+                 for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
+                     rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
+                     _commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
+                                 instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
+
+             # Increment count only on success
+             count += 1
+         except Exception as e:
+             logger.warning(f"[PROC] pubmedqa_l error in processing/augmentation for item {k}: {e}")
+             continue
+         if sample_limit and count >= sample_limit:
+             break
+         if cb and count % 1000 == 0:
+             cb(min(0.9, 0.05 + count/60000), f"pubmedqa_l processed {count}")
+     if cb:
+         cb(0.93, f"pubmedqa_l done ({count})")
+     logger.info(f"[PROC] pubmedqa_l done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
+     return count
+
+ def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
+     with open(path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+     count = 0
+     for k, v in data.items():
+         try:
+             q_raw = v.get("QUESTION") or ""
+             ctx_list = v.get("CONTEXTS") or []
+
+             # Ensure we have string values
+             q = str(q_raw).strip() if q_raw else ""
+             if isinstance(ctx_list, list):
+                 context = "\n".join(str(ctx) for ctx in ctx_list).strip()
+             else:
+                 context = str(ctx_list).strip()
+         except Exception as e:
+             logger.warning(f"[PROC] pubmedqa_u error processing item {k}: {e}, item: {v}")
+             continue
+
+         try:
+             instr = "Rewrite the context into a succinct note, then answer the question. If unknown, say 'insufficient evidence'."
+             user = f"Question: {q}\n\nContext:\n{context}" if context else f"Question: {q}"
+             out = ""  # unlabeled
+             rid = str(k)
+
+             # Optional KD/distillation for a small fraction
+             if opts.get("distill_fraction", 0.0) > 0.0 and random.random() < float(opts["distill_fraction"]):
+                 prompt = f"{instr}\n\n{user}\n\nAnswer briefly and safely."
+                 guess = paraphraser.paraphrase(prompt, difficulty="hard")  # cheap single call
+                 if guess and len(guess) < 2000:
+                     out = guess.strip()
+
+             instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
+             _commit_row(writer, "pubmedqa_u", rid, "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen)
+             if opts.get("expand", True):
+                 for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
+                     rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
+                     _commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
+                                 instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
+
+             # Increment count only on success
+             count += 1
+         except Exception as e:
+             logger.warning(f"[PROC] pubmedqa_u error in processing/augmentation for item {k}: {e}")
+             continue
+         if sample_limit and count >= sample_limit:
+             break
+         if cb and count % 2000 == 0:
+             cb(min(0.9, 0.05 + count/80000), f"pubmedqa_u processed {count}")
+     if cb:
+         cb(0.94, f"pubmedqa_u done ({count})")
+     logger.info(f"[PROC] pubmedqa_u done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
+     return count
+
+ def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
+     with open(path, "r", encoding="utf-8") as f:
+         obj = json.load(f)
+
+     # Log the structure for debugging
+     logger.info(f"[PROC] pubmedqa_map data type: {type(obj)}")
+     if isinstance(obj, dict):
+         logger.info(f"[PROC] pubmedqa_map dict keys: {list(obj.keys())}")
+         if len(obj) > 0:
+             sample_key = next(iter(obj.keys()))
+             sample_value = obj[sample_key]
+             logger.info(f"[PROC] pubmedqa_map sample value type: {type(sample_value)}")
+             if isinstance(sample_value, dict):
+                 logger.info(f"[PROC] pubmedqa_map sample value keys: {list(sample_value.keys())}")
+
+     # Iterate items in whichever shape the file uses
+     def iter_items():
+         try:
+             if isinstance(obj, list):
+                 for it in obj:
+                     if isinstance(it, dict):
+                         yield it
+                     else:
+                         logger.warning(f"[PROC] pubmedqa_map skipping non-dict list item: {type(it)}")
+             elif isinstance(obj, dict):
+                 qs, cs, ans = obj.get("question"), obj.get("context"), obj.get("answer")
+                 if isinstance(qs, list) and isinstance(cs, list) and isinstance(ans, list):
+                     for i in range(min(len(qs), len(cs), len(ans))):
+                         yield {"question": qs[i], "context": cs[i], "answer": ans[i]}
+                 else:
+                     # Handle case where values might be dictionaries or other objects
+                     for k, v in obj.items():
+                         if isinstance(v, dict):
+                             # If v is a dict, ensure it has the expected structure
+                             if "question" in v and "context" in v and "answer" in v:
+                                 yield v
+                             else:
+                                 # Try to map the keys to the expected structure
+                                 yield {
+                                     "question": v.get("question") or v.get("QUESTION") or str(k),
+                                     "context": v.get("context") or v.get("CONTEXT") or "",
+                                     "answer": v.get("answer") or v.get("ANSWER") or ""
+                                 }
+                         else:
+                             # If v is not a dict, create a simple structure
+                             yield {"question": str(k), "context": str(v) if v else "", "answer": ""}
+             else:
+                 logger.warning(f"[PROC] pubmedqa_map unexpected data type: {type(obj)}")
+         except Exception as e:
+             logger.error(f"[PROC] pubmedqa_map error in iter_items: {e}")
+             return
+
+     count = 0
+     for i, v in enumerate(iter_items(), start=1):
+         try:
+             # Ensure we have string values, convert if necessary
+             q_raw = v.get("question") or ""
+             c_raw = v.get("context") or ""
+             a_raw = v.get("answer") or ""
+
+             # Convert to string if not already
+             q = str(q_raw).strip() if q_raw else ""
+             c = str(c_raw).strip() if c_raw else ""
+             a = str(a_raw).strip() if a_raw else ""
+
+             instr = "Answer the biomedical question based on the context. Justify briefly."
+             user = f"Question: {q}\n\nContext:\n{c}" if c else f"Question: {q}"
+             out = a
+             rid = _hash_id("pubmedqa_map", i, len(q))
+
+             # Process the item
+             instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
+             _commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen)
+
+             # Handle expansion if enabled
+             if opts.get("expand", True):
+                 for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
+                     rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
+                     _commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
+                                 instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
+
+             # Increment count only on success
+             count += 1
+
+         except Exception as e:
+             logger.warning(f"[PROC] pubmedqa_map error processing item {i}: {e}, item: {v}")
+             continue
+
+         # Check sample limit
+         if sample_limit and count >= sample_limit:
+             break
+         if cb and i % 2000 == 0:
+             cb(min(0.9, 0.05 + i/120000), f"pubmedqa_map processed {i}")
+
+     if cb:
+         cb(0.95, f"pubmedqa_map done ({count})")
+     logger.info(f"[PROC] pubmedqa_map done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
+     return count
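
For reference, the shape of one JSONL row these processors emit via _commit_row/sft_row; the field values here are illustrative:

    # One base row from _proc_med_dialog, as written to the centralized JSONL:
    row = {
        "source": "icliniq",
        "id": "9f2c1a7d3b5e8c04",                     # 16-hex prefix from _hash_id
        "task": "medical_dialogue",
        "sft": {
            "instruction": "Answer the patient's question like a clinician. Be concise and safe.",
            "input": "I have had a mild headache for two days. What should I do?",
            "output": "Rest, hydrate, and consider paracetamol; see a doctor if it worsens.",
        },
        "meta": {"augmentations": ["base", "style_standardize"]},
    }
    # Augmented copies reuse the id with an "-augNNNN" suffix and carry their own tags.
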
utils/rag.py ADDED
@@ -0,0 +1,345 @@
+ # RAG-specific dataset processor
+ import json
+ import logging
+ import hashlib
+ import random
+ from typing import Dict, List, Tuple, Optional, Callable
+
+ from utils.schema import sft_row
+ from utils.llm import NvidiaClient, KeyRotator
+
+ # Logger
+ logger = logging.getLogger("rag_processor")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     logger.addHandler(logging.StreamHandler())
+
+ def _hash_id(*parts) -> str:
+     """Generate a hash ID for RAG entries"""
+     h = hashlib.sha256()
+     for p in parts:
+         h.update(str(p).encode("utf-8"))
+     return h.hexdigest()[:16]
+
+ def _iter_json_or_jsonl(path: str):
+     """Iterate over JSON or JSONL files"""
+     with open(path, "r", encoding="utf-8") as f:
+         first = f.read(1)
+         f.seek(0)
+         if first == "[":
+             data = json.load(f)
+             for obj in data:
+                 yield obj
+         else:
+             for line in f:
+                 line = line.strip()
+                 if line:
+                     yield json.loads(line)
+
+ class RAGProcessor:
+     """Processes medical datasets into RAG-specific QCA (Question, Context, Answer) format"""
+
+     def __init__(self, nvidia_model: str):
+         self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
+
+     def clean_conversational_content(self, text: str) -> str:
+         """Remove conversational elements and non-medical information using the NVIDIA model"""
+         if not text or len(text.strip()) < 10:
+             return text
+
+         prompt = f"""
+ You are a medical data cleaning expert. Clean the following text by:
+ 1. Remove conversational elements (greetings, pleasantries)
+ 2. Remove non-medical small talk and social interactions
+ 3. Keep only medically relevant information
+ 4. Preserve clinical facts, symptoms, diagnoses, treatments, and medical advice
+ 5. Maintain professional medical language
+ 6. Return only cleaned medical content, only plain text, no special characters, or formatting.
+
+ Text to clean:
+ {text}
+
+ Cleaned medical content:"""
+
+         try:
+             cleaned = self.nvidia_client.generate(
+                 prompt,
+                 temperature=0.1,
+                 max_tokens=min(1000, len(text) + 200)
+             )
+             return cleaned.strip() if cleaned else text
+         except Exception as e:
+             logger.warning(f"[RAG] Error cleaning text: {e}")
+             return text
+
+     def generate_context_from_qa(self, question: str, answer: str) -> str:
+         """Generate synthetic context from question and answer using the NVIDIA model"""
+         if not question or not answer:
+             return ""
+
+         prompt = f"""You are a medical knowledge expert. Given a medical question and its answer, generate a brief relevant medical context that would help someone understand the answer better. Write about 2 sentences that provide relevant background information. Use only plain text without any formatting or symbols.
+
+ Question: {question}
+
+ Answer: {answer}
+
+ Generate a concise medical context:"""
+
+         try:
+             context = self.nvidia_client.generate(
+                 prompt,
+                 temperature=0.2,
+                 max_tokens=200
+             )
+             return context.strip() if context else ""
+         except Exception as e:
+             logger.warning(f"[RAG] Error generating context: {e}")
+             return ""
+
+     def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
+         """Convert SFT format to QCA (Question, Context, Answer) format"""
+         # Clean the content to remove conversational elements
+         cleaned_input = self.clean_conversational_content(user_input)
+         cleaned_output = self.clean_conversational_content(output)
+
+         # Extract question from user input
+         question = self.extract_question(cleaned_input)
+
+         # Extract or generate context
+         context = self.extract_context(cleaned_input, question, cleaned_output)
+
+         # Clean answer
+         answer = cleaned_output
+
+         return question, context, answer
+
+     def extract_question(self, user_input: str) -> str:
+         """Extract the main question from user input"""
+         if not user_input:
+             return ""
+
+         # Try to identify question patterns
+         lines = user_input.split('\n')
+         for line in lines:
+             line = line.strip()
+             if line.startswith('Question:') or line.startswith('Q:'):
+                 return line.replace('Question:', '').replace('Q:', '').strip()
+             elif '?' in line and len(line) > 10:
+                 return line
+
+         # If no clear question found, use the first meaningful line
+         for line in lines:
+             line = line.strip()
+             if len(line) > 10:
+                 return line
+
+         return user_input
+
+     def extract_context(self, user_input: str, question: str, answer: str) -> str:
+         """Extract context from user input or generate synthetic context"""
+         # Look for context in the original input
+         context_candidates = []
+         lines = user_input.split('\n')
+
+         for line in lines:
+             line = line.strip()
+             if (line.startswith('Context:') or
+                     line.startswith('Background:') or
+                     line.startswith('Information:') or
+                     (len(line) > 50 and not line.startswith('Question:') and '?' not in line)):
+                 context_candidates.append(line)
+
+         if context_candidates:
+             # Clean and combine context candidates
+             context = ' '.join(context_candidates)
+             context = self.clean_conversational_content(context)
+             if len(context) > 20:  # Ensure we have meaningful context
+                 return context
+
+         # Generate synthetic context if none found
+         if question and answer:
+             synthetic_context = self.generate_context_from_qa(question, answer)
+             if synthetic_context:
+                 return synthetic_context
+
+         return ""
+
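
A sketch of the QCA conversion on a single forum-style record. Strings are illustrative, and the cleaning and context-generation steps call the NVIDIA endpoint, so NVIDIA_API_* keys must be configured for the calls to do anything:

    from utils.rag import RAGProcessor

    proc = RAGProcessor(nvidia_model="meta/llama-3.1-8b-instruct")  # model name illustrative
    q, c, a = proc.convert_to_qca_format(
        "Answer the medical question based on the provided context.",
        "Hi doctor! I am 34 and have had heartburn after meals for a week. What could it be?",
        "Hello! It sounds like reflux; try smaller meals and antacids. Regards.",
    )
    # q -> the extracted question line; c -> extracted or synthesized context; a -> cleaned answer
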
+     def process_medical_dialog(self, source: str, path: str, writer, sample_limit: Optional[int],
+                                stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None) -> int:
+         """Process medical dialogue datasets into RAG format"""
+         count = 0
+         written = 0
+
+         for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
+             try:
+                 instr_raw = obj.get("instruction") or "Answer the medical question based on the provided context."
+                 user_raw = obj.get("input") or ""
+                 out_raw = obj.get("output") or ""
+
+                 instr = str(instr_raw).strip()
+                 user = str(user_raw).strip()
+                 out = str(out_raw).strip()
+                 rid = _hash_id(source, i, len(user), len(out))
+
+                 # Convert to QCA format
+                 question, context, answer = self.convert_to_qca_format(instr, user, out)
+
+                 if not question or not answer:
+                     continue
+
+                 # Create RAG-specific instruction
+                 rag_instruction = "Answer the medical question based on the provided context. If the context is insufficient, provide the best available medical information."
+
+                 # Format user input as QCA
+                 if context:
+                     rag_user = f"Question: {question}\n\nContext: {context}"
+                 else:
+                     rag_user = f"Question: {question}"
+
+                 # Commit the RAG-formatted row
+                 if self._commit_rag_row(writer, source, rid, "rag_medical_qa",
+                                         rag_instruction, rag_user, answer,
+                                         stats, dedupe_seen=dedupe_seen):
+                     written += 1
+
+                 count += 1
+
+             except Exception as e:
+                 logger.warning(f"[RAG] {source} error processing item {i}: {e}")
+                 continue
+
+             if sample_limit and count >= sample_limit:
+                 break
+             if progress_cb and i % 1000 == 0:
+                 progress_cb(min(0.9, 0.05 + i/200000), f"{source}: processed {i} rows for RAG")
+
+         if progress_cb:
+             progress_cb(0.92, f"{source} RAG processing done ({count})")
+
+         logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
+         return count
+
+     def process_pubmedqa(self, source: str, path: str, writer, sample_limit: Optional[int],
+                          stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None) -> int:
+         """Process PubMedQA datasets into RAG format"""
+         with open(path, "r", encoding="utf-8") as f:
+             data = json.load(f)
+
+         count = 0
+         written = 0
+
+         for k, v in data.items():
+             try:
+                 q_raw = v.get("QUESTION") or ""
+                 ctx_list = v.get("CONTEXTS") or []
+                 long_ans_raw = v.get("LONG_ANSWER") or ""
+                 final_raw = v.get("final_decision") or ""
+
+                 question = str(q_raw).strip() if q_raw else ""
+                 if isinstance(ctx_list, list):
+                     context = "\n".join(str(ctx) for ctx in ctx_list).strip()
+                 else:
+                     context = str(ctx_list).strip()
+                 answer = str(long_ans_raw).strip() if long_ans_raw else str(final_raw).strip()
+
+                 if not question or not answer:
+                     continue
+
+                 # Clean the content
+                 question = self.clean_conversational_content(question)
+                 context = self.clean_conversational_content(context)
+                 answer = self.clean_conversational_content(answer)
+
+                 # Generate context if missing
+                 if not context:
+                     context = self.generate_context_from_qa(question, answer)
+
+                 rid = str(k)
+                 rag_instruction = "Answer the biomedical question based on the provided context."
+
+                 if context:
+                     rag_user = f"Question: {question}\n\nContext: {context}"
+                 else:
+                     rag_user = f"Question: {question}"
+
+                 # Commit the RAG-formatted row
+                 if self._commit_rag_row(writer, source, rid, "rag_biomedical_qa",
+                                         rag_instruction, rag_user, answer,
+                                         stats, dedupe_seen=dedupe_seen):
+                     written += 1
+
+                 count += 1
+
+             except Exception as e:
+                 logger.warning(f"[RAG] {source} error processing item {k}: {e}")
+                 continue
+
+             if sample_limit and count >= sample_limit:
+                 break
+             if progress_cb and count % 1000 == 0:
+                 progress_cb(min(0.9, 0.05 + count/60000), f"{source} RAG processed {count}")
+
+         if progress_cb:
+             progress_cb(0.93, f"{source} RAG processing done ({count})")
+
+         logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
+         return count
+
+     def _commit_rag_row(self, writer, source: str, rid: str, task: str,
+                         instruction: str, user_input: str, output: str,
+                         stats: Dict, dedupe_seen: set = None) -> bool:
+         """Commit a RAG-formatted row to the writer"""
+         # Simple deduplication based on content hash
+         if dedupe_seen is not None:
+             content_hash = hashlib.md5(f"{user_input}{output}".encode()).hexdigest()
+             if content_hash in dedupe_seen:
+                 stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
+                 return False
+             dedupe_seen.add(content_hash)
+
+         meta = {"rag_processing": True, "format": "qca"}
+         row = sft_row(instruction, user_input, output, source=source, rid=rid, task=task, meta=meta)
+         writer.write(row)
+         stats["written"] = stats.get("written", 0) + 1
+         return True
+
+ def process_file_into_rag(
+     dataset_key: str,
+     input_path: str,
+     writer,
+     nvidia_model: str,
+     sample_limit: Optional[int],
+     seed: int,
+     progress_cb: Optional[Callable[[float, str], None]]
+ ) -> Tuple[int, Dict]:
+     """Main entry point for RAG processing"""
+     random.seed(seed)
+     stats = {
+         "written": 0,
+         "dedup_skipped": 0
+     }
+
+     logger.info(f"[RAG] Begin RAG processing dataset={dataset_key} sample_limit={sample_limit}")
+
+     # Initialize RAG processor
+     rag_processor = RAGProcessor(nvidia_model)
+     dedupe_seen = set()
+
+     key = dataset_key.lower()
+     if key in ("healthcaremagic", "icliniq"):
+         count = rag_processor.process_medical_dialog(
+             source=key, path=input_path, writer=writer,
+             sample_limit=sample_limit, stats=stats,
+             progress_cb=progress_cb, dedupe_seen=dedupe_seen
+         )
+     elif key in ("pubmedqa_l", "pubmedqa_u", "pubmedqa_map"):
+         count = rag_processor.process_pubmedqa(
+             source=key, path=input_path, writer=writer,
+             sample_limit=sample_limit, stats=stats,
+             progress_cb=progress_cb, dedupe_seen=dedupe_seen
+         )
+     else:
+         raise ValueError(f"Unknown dataset for RAG processing: {dataset_key}")
+
+     logger.info(f"[RAG] End RAG processing dataset={dataset_key} stats={stats}")
+     return count, stats
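
A minimal driver sketch for the RAG entry point, mirroring the SFT one (paths and model name are illustrative):

    from utils.rag import process_file_into_rag
    from utils.schema import CentralisedWriter

    writer = CentralisedWriter("rag.jsonl", "rag.csv")  # output paths illustrative
    count, stats = process_file_into_rag(
        "pubmedqa_l", "cache/hf/ori_pqal.json", writer,
        nvidia_model="meta/llama-3.1-8b-instruct",       # illustrative
        sample_limit=500, seed=42, progress_cb=None
    )
    writer.close()
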
utils/schema.py ADDED
@@ -0,0 +1,68 @@
+ # Centralized SFT writer (JSONL + CSV)
+ import csv
+ import orjson
+ from typing import Optional, Dict
+ import logging
+
+ # Logger
+ logger = logging.getLogger("schema")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     logger.addHandler(logging.StreamHandler())
+
+ def sft_row(instruction: str, user_input: str, output: str, source: str, rid: str, task: str, meta: Optional[dict] = None):
+     return {
+         "source": source,
+         "id": rid,
+         "task": task,
+         "sft": {
+             "instruction": instruction,
+             "input": user_input,
+             "output": output
+         },
+         "meta": meta or {}
+     }
+
+ def is_valid_row(row: Dict, max_chars: int = 20000) -> bool:
+     s = row.get("sft", {})
+     instr = s.get("instruction", "")
+     inp = s.get("input", "")
+     out = s.get("output", "")
+     # basic sanity: non-empty input OR output; cap extremes
+     if not (inp or out): return False
+     if any(len(x) > max_chars for x in (instr, inp, out)): return False
+     return True
+
+ class CentralisedWriter:
+     """Streams JSONL + CSV in parallel to stay memory-safe."""
+     def __init__(self, jsonl_path: str, csv_path: str):
+         self.jsonl_fp = open(jsonl_path, "wb")
+         self.csv_fp = open(csv_path, "w", newline="", encoding="utf-8")
+         self.csv_wr = csv.DictWriter(self.csv_fp, fieldnames=["instruction","input","output","source","id","task"])
+         self.csv_wr.writeheader()
+
+     def write(self, row: dict):
+         if not is_valid_row(row):
+             s = row.get("sft", {})
+             logger.warning(
+                 f"[WRITER] Skipping invalid row id={row.get('id')} "
+                 f"(len instr={len(s.get('instruction',''))}, input={len(s.get('input',''))}, output={len(s.get('output',''))})"
+             )
+             return
+         self.jsonl_fp.write(orjson.dumps(row))
+         self.jsonl_fp.write(b"\n")
+         s = row["sft"]
+         self.csv_wr.writerow({
+             "instruction": s.get("instruction",""),
+             "input": s.get("input",""),
+             "output": s.get("output",""),
+             "source": row.get("source",""),
+             "id": row.get("id",""),
+             "task": row.get("task","")
+         })
+
+     def close(self):
+         try:
+             self.jsonl_fp.close()
+         finally:
+             self.csv_fp.close()
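
A minimal round-trip sketch of the writer: one valid row lands in both output files, while rows failing is_valid_row are logged and skipped (file names illustrative):

    from utils.schema import CentralisedWriter, sft_row

    w = CentralisedWriter("out.jsonl", "out.csv")
    w.write(sft_row(
        "Answer the question.", "What is hypertension?",
        "Persistently elevated blood pressure.",
        source="demo", rid="0001", task="biomedical_qa"
    ))
    w.write(sft_row("Answer.", "", "", source="demo", rid="0002", task="biomedical_qa"))  # skipped: empty
    w.close()  # one JSONL line plus one CSV row (after the header)
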
utils/token.py ADDED
@@ -0,0 +1,107 @@
+ # Google Drive OAuth credential/token refresher
+ import os, json, logging
+ from typing import Optional
+ from google.oauth2.credentials import Credentials
+ from google_auth_oauthlib.flow import Flow
+ from google.auth.transport.requests import Request
+
+ logger = logging.getLogger("token")
+ if not logger.handlers:
+     logger.setLevel(logging.INFO)
+     handler = logging.StreamHandler()
+     logger.addHandler(handler)
+
+ SCOPES = ["https://www.googleapis.com/auth/drive.file"]
+ TOKEN_FILE = os.getenv("GDRIVE_TOKEN_FILE", "cache/secrets/gdrive_token.json")
+
+ def _load_oauth_client_web():
+     cfg_env = os.getenv("GDRIVE_CREDENTIALS_JSON")
+     if not cfg_env:
+         return None
+     try:
+         cfg = json.loads(cfg_env)
+         return cfg.get("web")
+     except Exception as e:
+         logger.error(f"❌ Failed to parse GDRIVE_CREDENTIALS_JSON: {e}")
+         return None
+
+ def _ensure_dirs():
+     base = os.path.dirname(TOKEN_FILE)
+     if base and not os.path.exists(base):
+         os.makedirs(base, exist_ok=True)
+
+ def get_credentials() -> Optional[Credentials]:
+     # 1) Token file
+     if os.path.exists(TOKEN_FILE):
+         try:
+             with open(TOKEN_FILE, "r", encoding="utf-8") as f:
+                 data = json.load(f)
+             creds = Credentials.from_authorized_user_info(data, scopes=SCOPES)
+             if creds and creds.expired and creds.refresh_token:
+                 creds.refresh(Request())
+                 logger.info("🔄 Refreshed access token from token file")
+             return creds
+         except Exception as e:
+             logger.warning(f"⚠️ Failed to load token file: {e}")
+
+     # 2) Refresh token in env
+     refresh = os.getenv("GDRIVE_REFRESH_TOKEN")
+     web = _load_oauth_client_web()
+     if refresh and web:
+         creds = Credentials(
+             None,
+             refresh_token=refresh,
+             token_uri="https://oauth2.googleapis.com/token",
+             client_id=web.get("client_id"),
+             client_secret=web.get("client_secret"),
+             scopes=SCOPES,
+         )
+         if creds and (creds.expired or not creds.valid):
+             try:
+                 creds.refresh(Request())
+                 logger.info("🔄 Refreshed access token from env refresh token")
+             except Exception as e:
+                 logger.warning(f"⚠️ Refresh with env token failed: {e}")
+         return creds
+
+     # 3) Nothing available
+     return None
+
+ def build_auth_url(redirect_uri: str) -> str:
+     web = _load_oauth_client_web()
+     if not web:
+         raise RuntimeError("GDRIVE_CREDENTIALS_JSON missing or invalid ('web' section required)")
+     flow = Flow.from_client_config({"web": web}, scopes=SCOPES, redirect_uri=redirect_uri)
+     auth_url, _ = flow.authorization_url(
+         prompt="consent",
+         access_type="offline",
+         include_granted_scopes="true"
+     )
+     return auth_url
+
+ def exchange_code(code: str, redirect_uri: str) -> Credentials:
+     web = _load_oauth_client_web()
+     if not web:
+         raise RuntimeError("GDRIVE_CREDENTIALS_JSON missing or invalid ('web' section required)")
+     flow = Flow.from_client_config({"web": web}, scopes=SCOPES, redirect_uri=redirect_uri)
+     flow.fetch_token(code=code)
+     creds: Credentials = flow.credentials
+
+     info = {
+         "token": creds.token,
+         "refresh_token": creds.refresh_token,
+         "token_uri": "https://oauth2.googleapis.com/token",
+         "client_id": web.get("client_id"),
+         "client_secret": web.get("client_secret"),
+         "scopes": SCOPES,
+     }
+     _ensure_dirs()
+     with open(TOKEN_FILE, "w", encoding="utf-8") as f:
+         json.dump(info, f)
+     logger.info("✅ Saved Google refresh token to %s", TOKEN_FILE)
+
+     # also set env for current process
+     if creds.refresh_token:
+         os.environ["GDRIVE_REFRESH_TOKEN"] = creds.refresh_token
+
+     return creds
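
A one-time OAuth bootstrap sketch: build the consent URL, exchange the returned code, then rely on get_credentials() for silent refreshes afterwards. The redirect URI and code here are placeholders; the URI must match one registered on the OAuth client inside GDRIVE_CREDENTIALS_JSON:

    from utils.token import build_auth_url, exchange_code, get_credentials

    url = build_auth_url("http://localhost:7860/oauth2callback")  # illustrative redirect URI
    print("Visit:", url)                     # user consents and returns with ?code=...
    creds = exchange_code("PASTED_CODE", "http://localhost:7860/oauth2callback")
    assert get_credentials() is not None     # later runs refresh from the saved token file
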