Spaces: Sleeping

Commit · 80cb919 · 0 Parent(s)

Upd 14/9

Browse files

- .DS_Store +0 -0
- .gitattributes +35 -0
- .gitignore +4 -0
- DATA_PROCESSING.md +250 -0
- Dockerfile +31 -0
- LICENSE.txt +201 -0
- README.md +32 -0
- REQUEST.md +156 -0
- app.py +423 -0
- mount_drive.py +9 -0
- requirements.txt +13 -0
- utils/__init__.py +22 -0
- utils/.DS_Store +0 -0
- utils/augment.py +105 -0
- utils/datasets.py +66 -0
- utils/drive_saver.py +88 -0
- utils/llm.py +186 -0
- utils/processor.py +411 -0
- utils/rag.py +345 -0
- utils/schema.py +68 -0
- utils/token.py +107 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
.env
client1.json
client2.json
medai.json
DATA_PROCESSING.md
ADDED
@@ -0,0 +1,250 @@
# 📊 MedAI Data Processing Techniques

This document comprehensively outlines all the data processing techniques implemented in the MedAI Processing project for augmenting and centrally processing medical datasets for LLM fine-tuning.

## 🎯 Project Overview

The MedAI Processing system transforms raw medical datasets into a **centralized fine-tuning format** (JSONL + CSV) with comprehensive data augmentation capabilities. The system processes multiple medical dataset types and applies various enhancement techniques to improve data quality and diversity.

## 🏗️ System Architecture

### Core Components
- **FastAPI Web Service**: RESTful API for dataset processing
- **Multi-LLM Rotator**: NVIDIA API + Google Gemini integration
- **Centralized Writer**: Parallel JSONL + CSV output generation
- **Google Drive Integration**: Automated artifact storage
- **Progress Monitoring**: Real-time job status tracking

### Supported Datasets
1. **HealthCareMagic** (100k medical dialogues)
2. **iCliniq** (10k medical consultations)
3. **PubMedQA-Labelled** (biomedical Q&A with answers)
4. **PubMedQA-Unlabelled** (biomedical Q&A without answers)
5. **PubMedQA-Map** (biomedical Q&A mapping format)

## 🔧 Data Processing Pipeline

### 1. Data Ingestion & Download
- **Hugging Face Hub Integration**: Automatic dataset downloading
- **Format Detection**: JSON/JSONL auto-detection and parsing
- **Caching System**: Local storage with symlink optimization

### 2. Data Cleaning & Preprocessing

#### Text Normalization
- **Unicode Fixing**: `ftfy` library for text encoding issues
- **Whitespace Standardization**: Consistent spacing and line breaks
- **Quote Canonicalization**: Standard quote character conversion
- **Terminal Punctuation**: Ensures proper sentence endings

#### Content Sanitization
- **Length Capping**: Configurable maximum character limits (default: 5000)
- **Language Detection**: English language validation using `langid`
- **Content Truncation**: Smart sentence boundary cutting for long texts
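
A minimal sketch of how these cleaning steps could be chained. The project's own implementation lives in the `utils/` modules; the regexes, thresholds, and function name below are illustrative only.

```python
import re
import ftfy      # unicode repair
import langid    # language detection

MAX_CHARS = 5000  # default length cap described above

def clean_text(text: str, max_chars: int = MAX_CHARS):
    """Normalise encoding/whitespace, keep English rows, cap length at a sentence boundary."""
    text = ftfy.fix_text(text)                        # fix mojibake / encoding issues
    text = re.sub(r"\s+", " ", text).strip()          # standardise whitespace
    text = text.replace("“", '"').replace("”", '"')   # canonicalise quotes
    if not text:
        return None
    lang, _ = langid.classify(text)                   # drop non-English samples
    if lang != "en":
        return None
    if len(text) > max_chars:                         # truncate at the last full sentence
        cut = text[:max_chars]
        text = cut[: cut.rfind(".") + 1] or cut
    if text[-1] not in ".!?":                         # ensure terminal punctuation
        text += "."
    return text
```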

### 3. Data Augmentation Techniques

#### LLM-Based Paraphrasing
- **Multi-Model Rotation**: NVIDIA API (primary) + Gemini (fallback)
- **Difficulty Levels**: Easy vs. Hard paraphrasing modes
- **Medical Context Preservation**: Maintains clinical terminology accuracy
- **Configurable Ratios**: User-defined augmentation percentages (0.0-1.0)
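
The rotation can be pictured as a simple primary/fallback call. This is a hypothetical sketch only: the real client wiring is the `Paraphraser` class in `utils/llm.py`, and the prompt wording here is invented for illustration.

```python
from typing import Callable

def paraphrase(text: str,
               primary: Callable[[str], str],   # e.g. a wrapper around the NVIDIA endpoint
               fallback: Callable[[str], str],  # e.g. a wrapper around a Gemini model
               level: str = "easy") -> str:
    """Ask the primary model first; on any failure, retry once with the fallback model."""
    prompt = (
        "Paraphrase the following medical text. Preserve clinical terminology "
        f"and the original meaning. Difficulty: {level}.\n\n{text}"
    )
    try:
        return primary(prompt)
    except Exception:
        return fallback(prompt)
```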

#### Back-Translation Augmentation
- **Multi-Language Support**: German as intermediate language
- **Meaning Preservation**: Maintains semantic accuracy through translation cycles
- **Fallback Mechanisms**: Automatic retry with alternative models
- **Quality Control**: Length and content validation

#### Style Standardization
- **Clinical Voice Enforcement**: Neutral, professional medical tone
- **Absolute Language Removal**: Replaces guarantees with probabilistic language
- **Forum Sign-off Removal**: Eliminates informal communication patterns
- **Consistent Punctuation**: Standardized sentence structure

### 4. Data Quality Assurance

#### De-identification (PHI Removal)
- **Email Redaction**: `[REDACTED_EMAIL]` placeholder
- **Phone Number Masking**: `[REDACTED_PHONE]` placeholder
- **URL/IP Address Removal**: `[REDACTED_URL]` and `[REDACTED_IP]` placeholders
- **Configurable Privacy**: Optional PHI removal per dataset
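
An illustrative regex-based redaction pass in the spirit of the rules above (patterns simplified; the exact expressions used by the project live in the `utils/` modules):

```python
import re

PHI_PATTERNS = [
    (re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"), "[REDACTED_EMAIL]"),
    (re.compile(r"https?://\S+|www\.\S+"), "[REDACTED_URL]"),
    (re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b"), "[REDACTED_IP]"),
    (re.compile(r"\+?\d[\d\s().-]{7,}\d"), "[REDACTED_PHONE]"),
]

def deidentify(text: str) -> str:
    """Replace emails, URLs, IP addresses, and phone-like digit runs with placeholders."""
    for pattern, placeholder in PHI_PATTERNS:
        text = pattern.sub(placeholder, text)
    return text

# deidentify("Email jane@example.com or call 0412 345 678")
# -> "Email [REDACTED_EMAIL] or call [REDACTED_PHONE]"
```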

#### Deduplication
- **Fingerprinting Algorithm**: MD5-based content hashing
- **Multi-Field Matching**: Instruction + Input + Output combination
- **Normalized Comparison**: Case-insensitive, whitespace-normalized matching
- **Performance Optimized**: In-memory set-based deduplication
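
A sketch of the fingerprint check, assuming the three SFT fields are named `instruction`, `input`, and `output`:

```python
import hashlib

_seen: set = set()  # in-memory set of fingerprints already written

def fingerprint(instruction: str, input_text: str, output: str) -> str:
    """Case-insensitive, whitespace-normalised MD5 over the combined fields."""
    norm = " ".join(f"{instruction} {input_text} {output}".lower().split())
    return hashlib.md5(norm.encode("utf-8")).hexdigest()

def is_duplicate(row: dict) -> bool:
    fp = fingerprint(row.get("instruction", ""), row.get("input", ""), row.get("output", ""))
    if fp in _seen:
        return True
    _seen.add(fp)
    return False
```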

#### Consistency Validation
- **LLM-Based QA Check**: Automated answer validation against context
- **Configurable Sampling**: Ratio-based consistency checking (e.g., 0.01)
- **Medical Safety Validation**: Ensures clinical accuracy and safety
- **Failure Tagging**: Marks samples with consistency issues

### 5. Advanced Augmentation Features

#### Knowledge Distillation
- **Pseudo-Label Generation**: Creates labels for unlabeled data
- **Fractional Processing**: Configurable percentage for distillation
- **Single-Prompt Approach**: Efficient single LLM call per sample
- **Length Control**: Maintains reasonable output lengths

#### Multi-Variant Generation
- **Configurable Counts**: 1-3 augmented variants per sample
- **Tagged Augmentations**: Tracks applied augmentation techniques
- **Original Preservation**: Always maintains base sample
- **Randomized IDs**: Unique identifiers for augmented variants

### 6. Output Generation & Storage

#### Centralized Format
- **SFT Schema**: Standardized Supervised Fine-Tuning format
- **Metadata Preservation**: Source, task type, and augmentation tags
- **Dual Output**: Simultaneous JSONL and CSV generation
- **Memory-Safe Streaming**: Handles large datasets efficiently
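
For illustration, a single centralised record streamed to JSONL with `orjson` might look like the following; the exact field names are defined in `utils/schema.py` (`CentralisedWriter`) and may differ slightly:

```python
import orjson

record = {
    "id": "healthcaremagic-000123-aug1",          # hypothetical identifier
    "instruction": "Answer the patient's question as a clinician.",
    "input": "What could cause a persistent dry cough for three weeks?",
    "output": "A persistent dry cough may be caused by ...",
    "source": "HealthCareMagic",                  # metadata: dataset of origin
    "task_type": "medical_dialogue",              # metadata: task type
    "augmentation": ["paraphrase_easy"],          # metadata: augmentation tags
}

# Append one line per record; orjson.dumps returns bytes, so open in binary append mode
with open("cache/outputs/example.jsonl", "ab") as f:
    f.write(orjson.dumps(record) + b"\n")
```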

#### Storage Integration
- **Local Caching**: `cache/outputs/` directory storage
- **Google Drive Upload**: Automated cloud storage integration
- **Timestamped Naming**: Unique file identification
- **MIME Type Handling**: Proper content type specification

## ⚙️ Configuration Options

### Augmentation Parameters
```python
class AugmentOptions:
    paraphrase_ratio: float = 0.0         # 0.0-1.0
    paraphrase_outputs: bool = False      # Augment model answers
    backtranslate_ratio: float = 0.0      # 0.0-1.0
    style_standardize: bool = True        # Enforce clinical style
    deidentify: bool = True               # Remove PHI
    dedupe: bool = True                   # Remove duplicates
    max_chars: int = 5000                 # Text length limit
    consistency_check_ratio: float = 0.0  # 0.0-1.0
    distill_fraction: float = 0.0         # 0.0-1.0 for unlabeled
    expand: bool = True                   # Enable augmentation
    max_aug_per_sample: int = 2           # 1-3 variants
```

### Processing Parameters
```python
class ProcessParams:
    augment: AugmentOptions             # Augmentation settings
    sample_limit: Optional[int] = None  # Dataset sampling
    seed: int = 42                      # Reproducibility
```

## 📈 Performance & Monitoring

### Progress Tracking
- **Real-time Updates**: Live progress percentage and status messages
- **Background Processing**: Non-blocking job execution
- **State Management**: Thread-safe status tracking
- **Error Handling**: Comprehensive exception logging

### Resource Management
- **API Key Rotation**: Automatic fallback between multiple API keys
- **Rate Limiting**: Configurable request throttling
- **Memory Optimization**: Streaming processing for large datasets
- **Concurrent Processing**: Background task execution

## 🔒 Security & Privacy

### Data Protection
- **PHI Removal**: Automatic sensitive information redaction
- **Secure Storage**: Google Drive integration with OAuth2
- **Access Control**: Environment-based API key management
- **Audit Logging**: Comprehensive processing logs

### API Security
- **OAuth2 Integration**: Google Drive authentication
- **Token Management**: Secure credential handling
- **Request Validation**: Pydantic model validation
- **Error Sanitization**: Safe error message handling

## 🚀 Usage Examples

### Basic Processing
```bash
# Process HealthCareMagic with default settings
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{"augment": {"paraphrase_ratio": 0.1}}' \
  https://binkhoale1812-medai-processing.hf.space/process/healthcaremagic
```

### Advanced Augmentation
```bash
# Process with comprehensive augmentation
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "augment": {
      "paraphrase_ratio": 0.2,
      "backtranslate_ratio": 0.1,
      "paraphrase_outputs": true,
      "style_standardize": true,
      "deidentify": true,
      "dedupe": true,
      "max_chars": 5000,
      "consistency_check_ratio": 0.01,
      "max_aug_per_sample": 3
    },
    "sample_limit": 1000,
    "seed": 42
  }' \
  https://binkhoale1812-medai-processing.hf.space/process/icliniq
```

## 📊 Output Statistics

### Processing Metrics
- **Written Rows**: Total processed samples
- **Paraphrased Inputs**: Count of augmented user inputs
- **Paraphrased Outputs**: Count of augmented model responses
- **Back-translated**: Count of translation-augmented samples
- **Deduplication**: Count of skipped duplicate samples
- **Consistency Failures**: Count of validation failures

### File Outputs
- **JSONL Format**: Structured fine-tuning data with metadata
- **CSV Format**: Simplified tabular representation
- **Google Drive**: Cloud storage with automatic upload
- **Local Cache**: Persistent local storage

## 🔮 Future Enhancements

### Planned Features
- **Additional Dataset Support**: More medical dataset types
- **Advanced Augmentation**: More sophisticated LLM techniques
- **Quality Metrics**: Automated data quality scoring
- **Batch Processing**: Multiple dataset concurrent processing
- **Custom Schemas**: User-defined output formats

### Scalability Improvements
- **Distributed Processing**: Multi-node processing support
- **Streaming Augmentation**: Real-time data enhancement
- **Caching Optimization**: Improved performance and cost efficiency
- **API Rate Limiting**: Better resource management

## 📚 Technical Dependencies

### Core Libraries
- **FastAPI**: Web framework for API development
- **Hugging Face Hub**: Dataset downloading and management
- **Google GenAI**: Gemini model integration
- **ftfy**: Text encoding and normalization
- **langid**: Language detection
- **orjson**: High-performance JSON processing

### External Services
- **NVIDIA API**: Primary LLM service for paraphrasing
- **Google Gemini**: Fallback LLM service
- **Google Drive**: Cloud storage integration
- **Hugging Face Spaces**: Deployment platform

---

*This document provides a comprehensive overview of all data processing techniques implemented in the MedAI Processing project. For specific implementation details, refer to the individual module files in the `utils/` directory.*
Dockerfile
ADDED
@@ -0,0 +1,31 @@
FROM python:3.11-slim

# Install system dependencies as root (no sudo!)
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates curl && rm -rf /var/lib/apt/lists/*

# Create non-root user
RUN useradd -m -u 1000 user
ENV HOME=/home/user
WORKDIR $HOME/app

# Install Python dependencies first (better layer caching)
COPY --chown=user requirements.txt .
RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt

# Copy the application
COPY --chown=user . .

# Hugging Face cache setup
ENV HF_HOME="$HOME/.cache/huggingface"
ENV SENTENCE_TRANSFORMERS_HOME="$HOME/.cache/huggingface/sentence-transformers"
ENV MEDGEMMA_HOME="$HOME/.cache/huggingface/sentence-transformers"

# Prepare runtime dirs
RUN mkdir -p $HOME/app/logs $HOME/app/cache $HOME/app/cache/hf $HOME/app/cache/outputs && \
    chown -R user:user $HOME/app

USER user

EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2025 Dang Khoa Le

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
ADDED
@@ -0,0 +1,32 @@
---
title: MedAI Processing
emoji: ⚕️
colorFrom: indigo
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: Process and centralise medical docs for LLM fine-tuning
---

## Quick Access:

[HF Space](https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing)

[MedDialog-100k](https://huggingface.co/datasets/BinKhoaLe1812/MedDialog-EN-100k)

[MedDialog-10k](https://huggingface.co/datasets/BinKhoaLe1812/MedDialog-EN-10k)

[PubMedQA-Labelled](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-L)

[PubMedQA-Unlabelled](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-U)

[PubMedQA-Mapper](https://huggingface.co/datasets/BinKhoaLe1812/PubMedQA-MAP)


## CURL Request Instruction
[Request Doc](https://huggingface.co/spaces/MedAI-COS30018/MedAI_Processing/blob/main/REQUEST.md)

## License
[Apache-2.0 LICENSE](https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing/blob/main/LICENSE.txt)
REQUEST.md
ADDED
@@ -0,0 +1,156 @@
# 📑 MedAI Processing – Request Examples

Base URL of the Space:
**`https://binkhoale1812-medai-processing.hf.space`**

This Space processes medical datasets into a centralised fine-tuning format (JSONL + CSV) with optional augmentations such as **paraphrasing**, **back-translation**, **style standardisation**, **de-identification**, and **deduplication**.

---

## 🔹 1. Process HealthCareMagic

```bash
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "augment": {
      "paraphrase_ratio": 0.1,
      "backtranslate_ratio": 0.05,
      "paraphrase_outputs": false,
      "style_standardize": true,
      "deidentify": true,
      "dedupe": true,
      "max_chars": 5000
    },
    "sample_limit": 2000,
    "seed": 42
  }' \
  https://binkhoale1812-medai-processing.hf.space/process/healthcaremagic
```

---

## 🔹 2. Process iCliniq

```bash
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "augment": {
      "paraphrase_ratio": 0.2,
      "backtranslate_ratio": 0.1,
      "paraphrase_outputs": true,
      "style_standardize": true,
      "deidentify": true,
      "dedupe": true,
      "max_chars": 5000
    },
    "sample_limit": 1500,
    "seed": 123
  }' \
  https://binkhoale1812-medai-processing.hf.space/process/icliniq
```

---

## 🔹 3. Process PubMedQA (Labelled)

```bash
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "augment": {
      "paraphrase_ratio": 0.05,
      "backtranslate_ratio": 0.02,
      "paraphrase_outputs": false,
      "style_standardize": true,
      "deidentify": false,
      "dedupe": true,
      "max_chars": 8000
    },
    "sample_limit": 1000,
    "seed": 99
  }' \
  https://binkhoale1812-medai-processing.hf.space/process/pubmedqa_l
```

---

## 🔹 4. Process PubMedQA (Unlabelled)

```bash
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "augment": {
      "paraphrase_ratio": 0.05,
      "backtranslate_ratio": 0.05,
      "paraphrase_outputs": false,
      "style_standardize": true,
      "deidentify": true,
      "dedupe": true,
      "max_chars": 7000,
      "consistency_check_ratio": 0.01,
      "distill_fraction": 0.1
    },
    "sample_limit": 500,
    "seed": 7
  }' \
  https://binkhoale1812-medai-processing.hf.space/process/pubmedqa_u
```

---

## 🔹 5. Process PubMedQA (Map)

```bash
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "augment": {
      "paraphrase_ratio": 0.1,
      "backtranslate_ratio": 0.05,
      "paraphrase_outputs": true,
      "style_standardize": true,
      "deidentify": true,
      "dedupe": true,
      "max_chars": 6000
    },
    "sample_limit": 1200,
    "seed": 2024
  }' \
  https://binkhoale1812-medai-processing.hf.space/process/pubmedqa_map
```

---

## 🔹 6. Check Current Job Status

```bash
curl https://binkhoale1812-medai-processing.hf.space/status
```

---

## 🔹 7. List Generated Artifacts

```bash
curl https://binkhoale1812-medai-processing.hf.space/files
```

---

# ✅ Notes

* Each run outputs both `.jsonl` and `.csv` in `cache/outputs/` and also uploads them to Google Drive folder ID:
  `1JvW7its63E58fLxurH8ZdhxzdpcMrMbt`
* `augment` options can be adjusted per dataset:

  * `paraphrase_ratio` – % of rows paraphrased (0–1)
  * `backtranslate_ratio` – % of rows back-translated
  * `paraphrase_outputs` – whether to also augment model answers
  * `style_standardize` – enforce neutral, clinical style
  * `deidentify` – redact PHI (emails, phones, URLs, IPs)
  * `dedupe` – skip duplicate pairs
  * `consistency_check_ratio` – run lightweight QA sanity check
  * `distill_fraction` – generate pseudo-labels for unlabelled data
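
For scripted use, here is a minimal Python sketch (assuming the third-party `requests` package, which is not among this Space's own dependencies) that submits a job and then polls `/status` until the background run finishes:

```python
import time
import requests

BASE = "https://binkhoale1812-medai-processing.hf.space"

# Kick off a processing job (same payload shape as the curl examples above)
resp = requests.post(
    f"{BASE}/process/healthcaremagic",
    json={"augment": {"paraphrase_ratio": 0.1}, "seed": 42},
    timeout=60,
)
print(resp.json())

# Poll the job state exposed by /status until "running" flips back to false
while True:
    state = requests.get(f"{BASE}/status", timeout=60).json()
    print(state.get("progress"), state.get("message"))
    if not state.get("running"):
        break
    time.sleep(30)
```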
app.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Root FastAPI
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import time, logging
|
| 5 |
+
import threading
|
| 6 |
+
import datetime as dt
|
| 7 |
+
from typing import Optional, Dict
|
| 8 |
+
|
| 9 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
| 10 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
from utils.datasets import resolve_dataset, hf_download_dataset
|
| 15 |
+
from utils.processor import process_file_into_sft
|
| 16 |
+
from utils.rag import process_file_into_rag
|
| 17 |
+
from utils.drive_saver import DriveSaver
|
| 18 |
+
from utils.llm import Paraphraser
|
| 19 |
+
from utils.schema import CentralisedWriter
|
| 20 |
+
from utils.token import get_credentials, exchange_code, build_auth_url
|
| 21 |
+
|
| 22 |
+
# ────────── Log ───────────
|
| 23 |
+
logger = logging.getLogger("app")
|
| 24 |
+
if not logger.handlers:
|
| 25 |
+
logger.setLevel(logging.INFO)
|
| 26 |
+
handler = logging.StreamHandler()
|
| 27 |
+
logger.addHandler(handler)
|
| 28 |
+
|
| 29 |
+
# ────────── Boot ──────────
|
| 30 |
+
load_dotenv(override=True)
|
| 31 |
+
|
| 32 |
+
SPACE_NAME = os.getenv("SPACE_NAME", "MedAI Processor")
|
| 33 |
+
OUTPUT_DIR = os.path.abspath(os.getenv("OUTPUT_DIR", "cache/outputs"))
|
| 34 |
+
LOG_DIR = os.path.abspath(os.getenv("LOG_DIR", "logs"))
|
| 35 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 36 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
# --- Bootstrap Google OAuth ---
|
| 39 |
+
try:
|
| 40 |
+
creds = get_credentials()
|
| 41 |
+
if creds:
|
| 42 |
+
logger.info("✅ OAuth credentials loaded and valid")
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.warning(f"⚠️ OAuth not initialized yet: {e}")
|
| 45 |
+
|
| 46 |
+
# --- Bootstrap Google Drive ---
|
| 47 |
+
drive = DriveSaver(default_folder_id=os.getenv("GDRIVE_FOLDER_ID"))
|
| 48 |
+
|
| 49 |
+
# LLM rotator with paraphraser nodes
|
| 50 |
+
paraphraser = Paraphraser(
|
| 51 |
+
nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
|
| 52 |
+
gemini_model_easy=os.getenv("GEMINI_MODEL_EASY", "gemini-2.5-flash-lite"),
|
| 53 |
+
gemini_model_hard=os.getenv("GEMINI_MODEL_HARD", "gemini-2.5-flash"),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
app = FastAPI(title="Medical Dataset Augmenter", version="1.1.0")
|
| 57 |
+
|
| 58 |
+
STATE_LOCK = threading.Lock()
|
| 59 |
+
STATE: Dict[str, object] = {
|
| 60 |
+
"running": False,
|
| 61 |
+
"dataset": None,
|
| 62 |
+
"started_at": None,
|
| 63 |
+
"progress": 0.0,
|
| 64 |
+
"message": "idle",
|
| 65 |
+
"last_result": None
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
class AugmentOptions(BaseModel):
|
| 69 |
+
# ratios are 0..1
|
| 70 |
+
paraphrase_ratio: float = 0.0
|
| 71 |
+
paraphrase_outputs: bool = False
|
| 72 |
+
backtranslate_ratio: float = 0.0
|
| 73 |
+
style_standardize: bool = True
|
| 74 |
+
deidentify: bool = True
|
| 75 |
+
dedupe: bool = True
|
| 76 |
+
max_chars: int = 5000 # cap extremely long contexts
|
| 77 |
+
consistency_check_ratio: float = 0.0 # small ratio e.g. 0.01
|
| 78 |
+
# KD / distillation (optional, keeps default off)
|
| 79 |
+
distill_fraction: float = 0.0 # for unlabeled only
|
| 80 |
+
expand: bool = True # Enable back-translation and complex augmentation
|
| 81 |
+
max_aug_per_sample: int = 2 # Between 1-3, number of LLM call to augment/paraphrase data
|
| 82 |
+
|
| 83 |
+
class ProcessParams(BaseModel):
|
| 84 |
+
augment: AugmentOptions = AugmentOptions()
|
| 85 |
+
sample_limit: Optional[int] = None # Set data sampling if needed
|
| 86 |
+
seed: int = 42
|
| 87 |
+
rag_processing: bool = False # Enable RAG-specific processing
|
| 88 |
+
|
| 89 |
+
def set_state(**kwargs):
|
| 90 |
+
with STATE_LOCK:
|
| 91 |
+
STATE.update(kwargs)
|
| 92 |
+
|
| 93 |
+
def now_iso():
|
| 94 |
+
return dt.datetime.utcnow().isoformat()
|
| 95 |
+
|
| 96 |
+
# Instructional UI
|
| 97 |
+
@app.get("/", response_class=HTMLResponse)
|
| 98 |
+
def root():
|
| 99 |
+
return f"""
|
| 100 |
+
<html>
|
| 101 |
+
<head>
|
| 102 |
+
<title>{SPACE_NAME} – Medical Dataset Augmenter</title>
|
| 103 |
+
<style>
|
| 104 |
+
body {{ font-family: Arial, sans-serif; max-width: 900px; margin: 2rem auto; line-height: 1.5; }}
|
| 105 |
+
h1, h2 {{ color: #2c3e50; }}
|
| 106 |
+
button {{
|
| 107 |
+
background: #2d89ef; color: white; border: none; padding: 8px 16px;
|
| 108 |
+
border-radius: 5px; cursor: pointer; margin: 5px 0;
|
| 109 |
+
}}
|
| 110 |
+
button:hover {{ background: #1b5dab; }}
|
| 111 |
+
.section {{ margin-bottom: 2rem; }}
|
| 112 |
+
#log {{ background:#f5f5f5; padding:10px; border-radius:6px; margin-top:10px; font-size:0.9rem; }}
|
| 113 |
+
a {{ color:#2d89ef; text-decoration:none; }}
|
| 114 |
+
a:hover {{ text-decoration:underline; }}
|
| 115 |
+
</style>
|
| 116 |
+
</head>
|
| 117 |
+
<body>
|
| 118 |
+
<h1>📊 {SPACE_NAME} – Medical Dataset Augmenter</h1>
|
| 119 |
+
<p>This Hugging Face Space processes medical datasets into a <b>centralised fine-tuning format</b>
|
| 120 |
+
(JSONL + CSV), with optional <i>data augmentation</i>.</p>
|
| 121 |
+
|
| 122 |
+
<div class="section">
|
| 123 |
+
<h2>⚡ Quick Actions</h2>
|
| 124 |
+
<p>Click a button below to start processing a dataset with default augmentation parameters.</p>
|
| 125 |
+
<button onclick="startJob('healthcaremagic')">▶ProcAugment HealthCareMagic (100k)</button><br>
|
| 126 |
+
<button onclick="startJob('icliniq')">▶ProcAugment iCliniq (10k-derived)</button><br>
|
| 127 |
+
<button onclick="startJob('pubmedqa_l')">▶ProcAugment PubMedQA (Labelled)</button><br>
|
| 128 |
+
<button onclick="startJob('pubmedqa_u')">▶ProcAugment PubMedQA (Unlabelled)</button><br>
|
| 129 |
+
<button onclick="startJob('pubmedqa_map')">▶ProcAugment PubMedQA (Map)</button><br><br>
|
| 130 |
+
<div style="border-top: 1px solid #ddd; padding-top: 10px; margin-top: 10px;">
|
| 131 |
+
<strong>RAG Processing:</strong> - Convert to QCA format for RAG systems<br>
|
| 132 |
+
<button onclick="startRagJob('healthcaremagic')" style="background: #e74c3c;">▶ RAG HealthCareMagic (100k)</button><br>
|
| 133 |
+
<button onclick="startRagJob('icliniq')" style="background: #e74c3c;">▶ RAG iCliniq (10k-derived)</button><br>
|
| 134 |
+
<button onclick="startRagJob('pubmedqa_u')" style="background: #e74c3c;">▶ RAG PubMedQA (Unlabelled)</button><br>
|
| 135 |
+
<button onclick="startRagJob('pubmedqa_l')" style="background: #e74c3c;">▶ RAG PubMedQA (Labelled)</button><br>
|
| 136 |
+
<button onclick="startRagJob('pubmedqa_map')" style="background: #e74c3c;">▶ RAG PubMedQA (Map)</button>
|
| 137 |
+
</div>
|
| 138 |
+
</div>
|
| 139 |
+
|
| 140 |
+
<div class="section">
|
| 141 |
+
<h2>📂 Monitoring</h2>
|
| 142 |
+
<ul>
|
| 143 |
+
<li><a href="/status" target="_blank">Check current job status</a></li>
|
| 144 |
+
<li><a href="/files" target="_blank">List generated artifacts</a></li>
|
| 145 |
+
<li><a href="https://binkhoale1812-medai-processing.hf.space/oauth2/start" target="_blank">Authorize your GCS credential</a></li>
|
| 146 |
+
<li><a href="https://huggingface.co/spaces/BinKhoaLe1812/MedAI_Processing/blob/main/REQUEST.md" target="_blank">📑 Request Doc (all curl examples)</a></li>
|
| 147 |
+
</ul>
|
| 148 |
+
</div>
|
| 149 |
+
|
| 150 |
+
<div class="section">
|
| 151 |
+
<h2>📝 Log</h2>
|
| 152 |
+
<div id="log">Click a button above to run a job...</div>
|
| 153 |
+
</div>
|
| 154 |
+
|
| 155 |
+
<script>
|
| 156 |
+
async function startJob(dataset) {{
|
| 157 |
+
const log = document.getElementById("log");
|
| 158 |
+
const ragToggle = document.getElementById("ragToggle");
|
| 159 |
+
const isRagMode = ragToggle.checked;
|
| 160 |
+
|
| 161 |
+
log.innerHTML = "⏳ Starting " + (isRagMode ? "RAG " : "") + "job for <b>" + dataset + "</b>...";
|
| 162 |
+
try {{
|
| 163 |
+
const resp = await fetch("/process/" + dataset, {{
|
| 164 |
+
method: "POST",
|
| 165 |
+
headers: {{ "Content-Type": "application/json" }},
|
| 166 |
+
body: JSON.stringify({{
|
| 167 |
+
augment: {{
|
| 168 |
+
paraphrase_ratio: 0.1,
|
| 169 |
+
backtranslate_ratio: 0.00, // Increase to 0.05-0.1 for back-translation
|
| 170 |
+
paraphrase_outputs: false,
|
| 171 |
+
style_standardize: true,
|
| 172 |
+
deidentify: true,
|
| 173 |
+
dedupe: true,
|
| 174 |
+
max_chars: 5000,
|
| 175 |
+
expand: true,
|
| 176 |
+
max_aug_per_sample: 2
|
| 177 |
+
}},
|
| 178 |
+
sample_limit: null, // Sample down (currently disabled)
|
| 179 |
+
seed: 42,
|
| 180 |
+
rag_processing: isRagMode
|
| 181 |
+
}})
|
| 182 |
+
}});
|
| 183 |
+
const data = await resp.json();
|
| 184 |
+
if (resp.ok) {{
|
| 185 |
+
log.innerHTML = "✅ " + JSON.stringify(data);
|
| 186 |
+
}} else {{
|
| 187 |
+
log.innerHTML = "❌ Error: " + JSON.stringify(data);
|
| 188 |
+
}}
|
| 189 |
+
}} catch (err) {{
|
| 190 |
+
log.innerHTML = "❌ JS Error: " + err;
|
| 191 |
+
}}
|
| 192 |
+
}}
|
| 193 |
+
|
| 194 |
+
async function startRagJob(dataset) {{
|
| 195 |
+
const log = document.getElementById("log");
|
| 196 |
+
log.innerHTML = "⏳ Starting RAG processing for <b>" + dataset + "</b>...";
|
| 197 |
+
try {{
|
| 198 |
+
const resp = await fetch("/rag/" + dataset, {{
|
| 199 |
+
method: "POST",
|
| 200 |
+
headers: {{ "Content-Type": "application/json" }},
|
| 201 |
+
body: JSON.stringify({{
|
| 202 |
+
sample_limit: null,
|
| 203 |
+
seed: 42
|
| 204 |
+
}})
|
| 205 |
+
}});
|
| 206 |
+
const data = await resp.json();
|
| 207 |
+
if (resp.ok) {{
|
| 208 |
+
log.innerHTML = "✅ RAG Processing Started: " + JSON.stringify(data);
|
| 209 |
+
}} else {{
|
| 210 |
+
log.innerHTML = "❌ Error: " + JSON.stringify(data);
|
| 211 |
+
}}
|
| 212 |
+
}} catch (err) {{
|
| 213 |
+
log.innerHTML = "❌ JS Error: " + err;
|
| 214 |
+
}}
|
| 215 |
+
}}
|
| 216 |
+
</script>
|
| 217 |
+
</body>
|
| 218 |
+
</html>
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
@app.get("/status")
|
| 222 |
+
def status():
|
| 223 |
+
with STATE_LOCK:
|
| 224 |
+
return JSONResponse(STATE)
|
| 225 |
+
|
| 226 |
+
# ──────── GCS token ────────
|
| 227 |
+
@app.get("/oauth2/start")
|
| 228 |
+
def oauth2_start(request: Request):
|
| 229 |
+
# Compute redirect URI dynamically from the actual host the Space is using
|
| 230 |
+
host = request.headers.get("x-forwarded-host") or request.headers.get("host")
|
| 231 |
+
scheme = "https" # Spaces are HTTPS at the edge
|
| 232 |
+
redirect_uri = f"{scheme}://{host}/oauth2/callback"
|
| 233 |
+
|
| 234 |
+
try:
|
| 235 |
+
url = build_auth_url(redirect_uri)
|
| 236 |
+
return JSONResponse({"authorize_url": url})
|
| 237 |
+
except Exception as e:
|
| 238 |
+
raise HTTPException(500, f"OAuth init failed: {e}")
|
| 239 |
+
|
| 240 |
+
# Display your token
|
| 241 |
+
@app.get("/oauth2/callback")
|
| 242 |
+
def oauth2_callback(request: Request, code: str = "", state: str = ""):
|
| 243 |
+
if not code:
|
| 244 |
+
raise HTTPException(400, "Missing 'code'")
|
| 245 |
+
# Send req
|
| 246 |
+
host = request.headers.get("x-forwarded-host") or request.headers.get("host")
|
| 247 |
+
scheme = "https"
|
| 248 |
+
redirect_uri = f"{scheme}://{host}/oauth2/callback"
|
| 249 |
+
# Parse and show token code
|
| 250 |
+
try:
|
| 251 |
+
creds = exchange_code(code, redirect_uri)
|
| 252 |
+
refresh = creds.refresh_token or os.getenv("GDRIVE_REFRESH_TOKEN", "")
|
| 253 |
+
# UI
|
| 254 |
+
html = f"""
|
| 255 |
+
<html>
|
| 256 |
+
<head>
|
| 257 |
+
<style>
|
| 258 |
+
body {{ font-family: sans-serif; margin: 2em; }}
|
| 259 |
+
.token-box {{
|
| 260 |
+
padding: 1em; border: 1px solid #ccc; border-radius: 6px;
|
| 261 |
+
background: #f9f9f9; font-family: monospace;
|
| 262 |
+
word-break: break-all; white-space: pre-wrap;
|
| 263 |
+
}}
|
| 264 |
+
.note {{ margin-top: 1em; color: #555; }}
|
| 265 |
+
</style>
|
| 266 |
+
</head>
|
| 267 |
+
<body>
|
| 268 |
+
<h2>✅ Google Drive Authorized</h2>
|
| 269 |
+
<p>Your refresh token is:</p>
|
| 270 |
+
<div class="token-box">{refresh}</div>
|
| 271 |
+
<p class="note">
|
| 272 |
+
👉 Copy this token and save it into your Hugging Face Space Secrets
|
| 273 |
+
as <code>GDRIVE_REFRESH_TOKEN</code>.
|
| 274 |
+
This ensures persistence across rebuilds.
|
| 275 |
+
</p>
|
| 276 |
+
</body>
|
| 277 |
+
</html>
|
| 278 |
+
"""
|
| 279 |
+
return HTMLResponse(html)
|
| 280 |
+
except Exception as e:
|
| 281 |
+
raise HTTPException(500, f"OAuth exchange failed: {e}")
|
| 282 |
+
|
| 283 |
+
@app.get("/files")
|
| 284 |
+
def files():
|
| 285 |
+
out = []
|
| 286 |
+
for root, _, fns in os.walk(OUTPUT_DIR):
|
| 287 |
+
for fn in fns:
|
| 288 |
+
out.append(os.path.relpath(os.path.join(root, fn), OUTPUT_DIR))
|
| 289 |
+
return {"output_dir": OUTPUT_DIR, "files": sorted(out)}
|
| 290 |
+
|
| 291 |
+
@app.post("/process/{dataset_key}")
|
| 292 |
+
def process_dataset(dataset_key: str, params: ProcessParams, background: BackgroundTasks):
|
| 293 |
+
with STATE_LOCK:
|
| 294 |
+
if STATE["running"]:
|
| 295 |
+
logger.warning(
|
| 296 |
+
f"[JOB] Rejecting new job dataset={dataset_key} "
|
| 297 |
+
f"current={STATE['dataset']} started_at={STATE['started_at']}"
|
| 298 |
+
)
|
| 299 |
+
raise HTTPException(409, detail="Another job is running.")
|
| 300 |
+
STATE["running"] = True
|
| 301 |
+
STATE["dataset"] = dataset_key
|
| 302 |
+
STATE["started_at"] = now_iso()
|
| 303 |
+
STATE["progress"] = 0.0
|
| 304 |
+
STATE["message"] = "starting"
|
| 305 |
+
STATE["last_result"] = None
|
| 306 |
+
logger.info(
|
| 307 |
+
f"[JOB] Queued dataset={dataset_key} "
|
| 308 |
+
f"params={{'sample_limit': {params.sample_limit}, 'seed': {params.seed}, "
|
| 309 |
+
f"'rag_processing': {params.rag_processing}, 'augment': {params.augment.dict()} }}"
|
| 310 |
+
)
|
| 311 |
+
# Start job to background runner thread
|
| 312 |
+
logger.info(f"[JOB] Started dataset={dataset_key}")
|
| 313 |
+
background.add_task(_run_job, dataset_key, params)
|
| 314 |
+
return {"ok": True, "message": f"Job for '{dataset_key}' started."}
|
| 315 |
+
|
| 316 |
+
@app.post("/rag/{dataset_key}")
|
| 317 |
+
def process_rag_dataset(dataset_key: str, params: ProcessParams, background: BackgroundTasks):
|
| 318 |
+
"""Dedicated RAG processing endpoint"""
|
| 319 |
+
# Force RAG processing mode
|
| 320 |
+
params.rag_processing = True
|
| 321 |
+
|
| 322 |
+
with STATE_LOCK:
|
| 323 |
+
if STATE["running"]:
|
| 324 |
+
logger.warning(
|
| 325 |
+
f"[RAG] Rejecting new RAG job dataset={dataset_key} "
|
| 326 |
+
f"current={STATE['dataset']} started_at={STATE['started_at']}"
|
| 327 |
+
)
|
| 328 |
+
raise HTTPException(409, detail="Another job is running.")
|
| 329 |
+
STATE["running"] = True
|
| 330 |
+
STATE["dataset"] = dataset_key
|
| 331 |
+
STATE["started_at"] = now_iso()
|
| 332 |
+
STATE["progress"] = 0.0
|
| 333 |
+
STATE["message"] = "starting RAG processing"
|
| 334 |
+
STATE["last_result"] = None
|
| 335 |
+
logger.info(
|
| 336 |
+
f"[RAG] Queued RAG dataset={dataset_key} "
|
| 337 |
+
f"params={{'sample_limit': {params.sample_limit}, 'seed': {params.seed} }}"
|
| 338 |
+
)
|
| 339 |
+
# Hand the job off to a background runner thread
|
| 340 |
+
logger.info(f"[RAG] Started RAG dataset={dataset_key}")
|
| 341 |
+
background.add_task(_run_job, dataset_key, params)
|
| 342 |
+
return {"ok": True, "message": f"RAG processing job for '{dataset_key}' started."}
|
| 343 |
+
|
| 344 |
+
def _run_job(dataset_key: str, params: ProcessParams):
|
| 345 |
+
t0 = time.time()
|
| 346 |
+
try:
|
| 347 |
+
ds = resolve_dataset(dataset_key)
|
| 348 |
+
if not ds:
|
| 349 |
+
set_state(running=False, message="unknown dataset")
|
| 350 |
+
return
|
| 351 |
+
|
| 352 |
+
# Download the HF dataset, then start processing
|
| 353 |
+
set_state(message="downloading")
|
| 354 |
+
local_path = hf_download_dataset(ds["repo_id"], ds["filename"], ds["repo_type"])
|
| 355 |
+
logger.info(f"[JOB] Downloaded {ds['repo_id']}/{ds['filename']} → {local_path}")
|
| 356 |
+
|
| 357 |
+
# Prepare timestamp for file writing
|
| 358 |
+
ts = dt.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
| 359 |
+
mode_suffix = "rag" if params.rag_processing else "sft"
|
| 360 |
+
stem = f"{dataset_key}-{mode_suffix}-{ts}"
|
| 361 |
+
jsonl_path = os.path.join(OUTPUT_DIR, f"{stem}.jsonl")
|
| 362 |
+
csv_path = os.path.join(OUTPUT_DIR, f"{stem}.csv")
|
| 363 |
+
# Change state
|
| 364 |
+
set_state(message="processing", progress=0.05)
|
| 365 |
+
|
| 366 |
+
# Writer
|
| 367 |
+
writer = CentralisedWriter(jsonl_path=jsonl_path, csv_path=csv_path)
|
| 368 |
+
|
| 369 |
+
if params.rag_processing:
|
| 370 |
+
# RAG processing mode
|
| 371 |
+
set_state(message="RAG processing", progress=0.1)
|
| 372 |
+
count, stats = process_file_into_rag(
|
| 373 |
+
dataset_key=dataset_key,
|
| 374 |
+
input_path=local_path,
|
| 375 |
+
writer=writer,
|
| 376 |
+
nvidia_model=os.getenv("NVIDIA_MODEL", "meta/llama-3.1-8b-instruct"),
|
| 377 |
+
sample_limit=params.sample_limit,
|
| 378 |
+
seed=params.seed,
|
| 379 |
+
progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"])
|
| 380 |
+
)
|
| 381 |
+
else:
|
| 382 |
+
# Standard SFT processing mode
|
| 383 |
+
set_state(message="SFT processing", progress=0.1)
|
| 384 |
+
count, stats = process_file_into_sft(
|
| 385 |
+
dataset_key=dataset_key,
|
| 386 |
+
input_path=local_path,
|
| 387 |
+
writer=writer,
|
| 388 |
+
paraphraser=paraphraser,
|
| 389 |
+
augment_opts=params.augment.dict(),
|
| 390 |
+
sample_limit=params.sample_limit,
|
| 391 |
+
seed=params.seed,
|
| 392 |
+
progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"])
|
| 393 |
+
)
|
| 394 |
+
logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
|
| 395 |
+
writer.close()
|
| 396 |
+
|
| 397 |
+
# Upload to GDrive
|
| 398 |
+
set_state(message="uploading to Google Drive", progress=0.95)
|
| 399 |
+
up1 = drive.upload_file_to_drive(jsonl_path, mimetype="application/json")
|
| 400 |
+
up2 = drive.upload_file_to_drive(csv_path, mimetype="text/csv")
|
| 401 |
+
logger.info(
|
| 402 |
+
f"[JOB] Uploads complete uploaded={bool(up1 and up2)} "
|
| 403 |
+
f"jsonl={jsonl_path} csv={csv_path}"
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Finalize the job result
|
| 407 |
+
result = {
|
| 408 |
+
"dataset": dataset_key,
|
| 409 |
+
"processing_mode": "RAG" if params.rag_processing else "SFT",
|
| 410 |
+
"processed_rows": count,
|
| 411 |
+
"stats": stats,
|
| 412 |
+
"artifacts": {"jsonl": jsonl_path, "csv": csv_path},
|
| 413 |
+
"uploaded": bool(up1 and up2),
|
| 414 |
+
"duration_sec": round(time.time() - t0, 2)
|
| 415 |
+
}
|
| 416 |
+
set_state(message="done", progress=1.0, last_result=result, running=False)
|
| 417 |
+
logger.info(
|
| 418 |
+
f"[JOB] Finished dataset={dataset_key} "
|
| 419 |
+
f"duration_sec={round(time.time()-t0, 2)}"
|
| 420 |
+
)
|
| 421 |
+
except Exception as e:
|
| 422 |
+
logger.exception(f"[JOB] Error for dataset={dataset_key}: {e}")
|
| 423 |
+
set_state(message=f"error: {e}", running=False)
|
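
Usage note (not part of the committed files): a minimal client sketch for exercising the endpoints above. The base URL, the small sample_limit, and the partial augment dict are placeholders; the field names mirror what process_dataset logs, and any omitted augment options are assumed to fall back to their defaults.

# Hypothetical client sketch for the endpoints above; BASE_URL and payload values are placeholders
import requests

BASE_URL = "http://localhost:7860"  # assumption: local run or the deployed Space URL

payload = {
    "sample_limit": 100,              # small smoke-test run
    "seed": 42,
    "rag_processing": False,          # /process/... runs SFT mode unless this is True
    "augment": {"paraphrase_ratio": 0.1, "backtranslate_ratio": 0.0},  # other options assumed to default
}
resp = requests.post(f"{BASE_URL}/process/icliniq", json=payload, timeout=30)
print(resp.status_code, resp.json())  # {"ok": True, ...} on start, or 409 if another job is running

# List the artifacts written under OUTPUT_DIR once the job finishes
files = requests.get(f"{BASE_URL}/files", timeout=30).json()
print(files["output_dir"], files["files"])
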
mount_drive.py
ADDED
|
@@ -0,0 +1,9 @@
| 1 |
+
# Check Google Drive status
|
| 2 |
+
from utils.drive_saver import DriveSaver
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
ds = DriveSaver()
|
| 6 |
+
if ds.is_service_available():
|
| 7 |
+
print("Drive ready.")
|
| 8 |
+
else:
|
| 9 |
+
print("Drive NOT ready.")
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
python-dotenv
|
| 4 |
+
huggingface_hub
|
| 5 |
+
requests
|
| 6 |
+
google-genai
|
| 7 |
+
google-api-python-client
|
| 8 |
+
google-auth
|
| 9 |
+
google-auth-httplib2
|
| 10 |
+
google-auth-oauthlib
|
| 11 |
+
orjson
|
| 12 |
+
ftfy
|
| 13 |
+
langid
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
| 1 |
+
"""
|
| 2 |
+
Utility package for the Medical Dataset Augmenter Space.
|
| 3 |
+
|
| 4 |
+
This package provides:
|
| 5 |
+
- drive_saver: Google Drive upload helper
|
| 6 |
+
- llm: API key rotation, paraphraser, translation/backtranslation
|
| 7 |
+
- datasets: Hugging Face dataset resolver & downloader
|
| 8 |
+
- processor: dataset-specific processing pipeline with augmentation
|
| 9 |
+
- schema: centralised SFT writer (JSONL + CSV)
|
| 10 |
+
- token: GCS project token refresher and authenticator
|
| 11 |
+
- augment: low-level augmentation utilities (text cleanup, deid, paraphrase hooks)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from . import drive_saver
|
| 15 |
+
from . import llm
|
| 16 |
+
from . import datasets
|
| 17 |
+
from . import processor
|
| 18 |
+
from . import schema
|
| 19 |
+
from . import augment
|
| 20 |
+
from . import token
|
| 21 |
+
|
| 22 |
+
__all__ = ["drive_saver", "llm", "datasets", "processor", "schema", "augment"]
|
utils/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
utils/augment.py
ADDED
|
@@ -0,0 +1,105 @@
| 1 |
+
# augmentation utility agent
|
| 2 |
+
import re
|
| 3 |
+
import random
|
| 4 |
+
from typing import Dict, Tuple
|
| 5 |
+
import ftfy
|
| 6 |
+
import langid
|
| 7 |
+
|
| 8 |
+
P_EMAIL = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
|
| 9 |
+
P_PHONE = re.compile(r"(?:(?:\+?\d{1,3})?[\s-]?)?(?:\(?\d{2,4}\)?[\s-]?)?\d{3,4}[\s-]?\d{3,4}")
|
| 10 |
+
P_URL = re.compile(r"https?://\S+|www\.\S+")
|
| 11 |
+
P_IP = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
|
| 12 |
+
|
| 13 |
+
def fix_unicode(s: str) -> str:
|
| 14 |
+
return ftfy.fix_text(s or "")
|
| 15 |
+
|
| 16 |
+
def normalize_whitespace(s: str) -> str:
|
| 17 |
+
s = s.replace("\u00A0", " ")
|
| 18 |
+
s = re.sub(r"[ \t]+", " ", s)
|
| 19 |
+
s = re.sub(r"\s+\n", "\n", s)
|
| 20 |
+
s = re.sub(r"\n{3,}", "\n\n", s)
|
| 21 |
+
return s.strip()
|
| 22 |
+
|
| 23 |
+
def canonicalize_quotes(s: str) -> str:
|
| 24 |
+
return s.replace("“", '"').replace("”", '"').replace("’", "'").replace("‘", "'")
|
| 25 |
+
|
| 26 |
+
def ensure_terminal_punct(s: str) -> str:
|
| 27 |
+
if not s: return s
|
| 28 |
+
if s[-1] in ".!?": return s
|
| 29 |
+
return s + "."
|
| 30 |
+
|
| 31 |
+
def deidentify(s: str) -> str:
|
| 32 |
+
s = P_EMAIL.sub("[REDACTED_EMAIL]", s)
|
| 33 |
+
s = P_PHONE.sub("[REDACTED_PHONE]", s)
|
| 34 |
+
s = P_URL.sub("[REDACTED_URL]", s)
|
| 35 |
+
s = P_IP.sub("[REDACTED_IP]", s)
|
| 36 |
+
return s
|
| 37 |
+
|
| 38 |
+
def lang_is_english(s: str) -> bool:
|
| 39 |
+
try:
|
| 40 |
+
lang, _ = langid.classify((s or "")[:2000])
|
| 41 |
+
return lang == "en"
|
| 42 |
+
except Exception:
|
| 43 |
+
return True
|
| 44 |
+
|
| 45 |
+
def length_cap(s: str, max_chars: int) -> str:
|
| 46 |
+
if len(s) <= max_chars:
|
| 47 |
+
return s
|
| 48 |
+
# try to cut at sentence boundary
|
| 49 |
+
cut = s[:max_chars]
|
| 50 |
+
last_dot = cut.rfind(". ")
|
| 51 |
+
if last_dot > 300: # don't cut too aggressively
|
| 52 |
+
return cut[:last_dot+1] + " …"
|
| 53 |
+
return cut + " …"
|
| 54 |
+
|
| 55 |
+
def fingerprint(instr: str, user: str, out: str) -> str:
|
| 56 |
+
# Simple, fast fingerprint for dedupe
|
| 57 |
+
def norm(x: str) -> str:
|
| 58 |
+
x = x.lower()
|
| 59 |
+
x = re.sub(r"[^a-z0-9]+", " ", x)
|
| 60 |
+
x = re.sub(r"\s+", " ", x).strip()
|
| 61 |
+
return x
|
| 62 |
+
core = "||".join([norm(instr), norm(user), norm(out)])
|
| 63 |
+
# lightweight hash
|
| 64 |
+
import hashlib
|
| 65 |
+
return hashlib.md5(core.encode("utf-8")).hexdigest()
|
| 66 |
+
|
| 67 |
+
def style_standardize_answer(ans: str) -> str:
|
| 68 |
+
if not ans: return ans
|
| 69 |
+
ans = ans.strip()
|
| 70 |
+
# Gentle guardrails, neutral voice
|
| 71 |
+
prefix = ""
|
| 72 |
+
# Avoid absolute guarantees
|
| 73 |
+
ans = re.sub(r"\b(guarantee|100%|certainly|always|never)\b", "likely", ans, flags=re.I)
|
| 74 |
+
# Remove sign-offs typical of forums
|
| 75 |
+
ans = re.sub(r"\n*(thanks|thank you|regards|cheers)[^\n]*$", "", ans, flags=re.I)
|
| 76 |
+
return ensure_terminal_punct(ans)
|
| 77 |
+
|
| 78 |
+
def base_cleanup(s: str, max_chars: int, do_deid: bool) -> str:
|
| 79 |
+
s = fix_unicode(s)
|
| 80 |
+
s = canonicalize_quotes(s)
|
| 81 |
+
s = normalize_whitespace(s)
|
| 82 |
+
if do_deid:
|
| 83 |
+
s = deidentify(s)
|
| 84 |
+
s = length_cap(s, max_chars)
|
| 85 |
+
return s
|
| 86 |
+
|
| 87 |
+
def maybe_paraphrase(text: str, ratio: float, paraphraser, difficulty: str) -> Tuple[str, bool]:
|
| 88 |
+
if ratio <= 0 or not text: return text, False
|
| 89 |
+
if random.random() < ratio:
|
| 90 |
+
return paraphraser.paraphrase(text, difficulty=difficulty), True
|
| 91 |
+
return text, False
|
| 92 |
+
|
| 93 |
+
def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool]:
|
| 94 |
+
if ratio <= 0 or not text: return text, False
|
| 95 |
+
if random.random() < ratio:
|
| 96 |
+
bt = paraphraser.backtranslate(text, via_lang="de")
|
| 97 |
+
return bt if bt else text, bool(bt)
|
| 98 |
+
return text, False
|
| 99 |
+
|
| 100 |
+
def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
|
| 101 |
+
if ratio <= 0 or (not user) or (not out):
|
| 102 |
+
return True
|
| 103 |
+
if random.random() >= ratio:
|
| 104 |
+
return True
|
| 105 |
+
return paraphraser.consistency_check(user, out)
|
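
A quick sketch of how the cleanup helpers above compose; the input string is invented for illustration.

# Illustrative only: clean a made-up patient message, then fingerprint it for dedupe
from utils.augment import base_cleanup, fingerprint, lang_is_english

raw = "Hi  doc,\u00A0please email jane.doe@example.com about my BP readings!!"
clean = base_cleanup(raw, max_chars=5000, do_deid=True)
print(clean)                                   # whitespace normalized, email redacted
print(lang_is_english(clean))                  # True
print(fingerprint("instruction", clean, "Sample answer."))  # stable hex digest used for dedupe
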
utils/datasets.py
ADDED
|
@@ -0,0 +1,66 @@
| 1 |
+
# HF dataset download resolver + downloader
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
# Logger
|
| 8 |
+
logger = logging.getLogger("datasets")
|
| 9 |
+
if not logger.handlers:
|
| 10 |
+
logger.setLevel(logging.INFO)
|
| 11 |
+
logger.addHandler(logging.StreamHandler())
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DATASETS = {
|
| 15 |
+
"healthcaremagic": {
|
| 16 |
+
"repo_id": "BinKhoaLe1812/MedDialog-EN-100k",
|
| 17 |
+
"filename": "HealthCareMagic-100k.json",
|
| 18 |
+
"repo_type": "dataset"
|
| 19 |
+
},
|
| 20 |
+
"icliniq": {
|
| 21 |
+
"repo_id": "BinKhoaLe1812/MedDialog-EN-10k",
|
| 22 |
+
"filename": "iCliniq.json",
|
| 23 |
+
"repo_type": "dataset"
|
| 24 |
+
},
|
| 25 |
+
"pubmedqa_l": {
|
| 26 |
+
"repo_id": "BinKhoaLe1812/PubMedQA-L",
|
| 27 |
+
"filename": "ori_pqal.json",
|
| 28 |
+
"repo_type": "dataset"
|
| 29 |
+
},
|
| 30 |
+
"pubmedqa_u": {
|
| 31 |
+
"repo_id": "BinKhoaLe1812/PubMedQA-U",
|
| 32 |
+
"filename": "ori_pqau.json",
|
| 33 |
+
"repo_type": "dataset"
|
| 34 |
+
},
|
| 35 |
+
"pubmedqa_map": {
|
| 36 |
+
"repo_id": "BinKhoaLe1812/PubMedQA-Map",
|
| 37 |
+
"filename": "pubmed_qa_map.json",
|
| 38 |
+
"repo_type": "dataset"
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def resolve_dataset(key: str) -> Optional[dict]:
|
| 44 |
+
return DATASETS.get(key.lower())
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def hf_download_dataset(repo_id: str, filename: str, repo_type: str = "dataset") -> str:
|
| 48 |
+
token = os.getenv("HF_TOKEN")
|
| 49 |
+
logger.info(
|
| 50 |
+
f"[HF] Download {repo_id}/{filename} (type={repo_type}) token={'yes' if token else 'no'}"
|
| 51 |
+
)
|
| 52 |
+
path = hf_hub_download(
|
| 53 |
+
repo_id=repo_id,
|
| 54 |
+
filename=filename,
|
| 55 |
+
repo_type=repo_type,
|
| 56 |
+
token=token,
|
| 57 |
+
local_dir=os.path.abspath("cache/hf"),
|
| 58 |
+
local_dir_use_symlinks=False
|
| 59 |
+
)
|
| 60 |
+
try:
|
| 61 |
+
size = os.path.getsize(path)
|
| 62 |
+
logger.info(f"[HF] Downloaded to {path} size={size} bytes")
|
| 63 |
+
except Exception:
|
| 64 |
+
logger.info(f"[HF] Downloaded to {path}")
|
| 65 |
+
return path
|
| 66 |
+
|
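
For reference, resolving and fetching one of the registered datasets looks roughly like this; it needs network access to Hugging Face and, if the repos are gated, an HF_TOKEN in the environment.

# Sketch: resolve a dataset key and download its file into cache/hf
import os
from utils.datasets import resolve_dataset, hf_download_dataset

ds = resolve_dataset("icliniq")
if ds is None:
    raise SystemExit("unknown dataset key")
path = hf_download_dataset(ds["repo_id"], ds["filename"], ds["repo_type"])
print("downloaded to", path, os.path.getsize(path), "bytes")
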
utils/drive_saver.py
ADDED
|
@@ -0,0 +1,88 @@
| 1 |
+
# Save final post-process to Google Drive
|
| 2 |
+
import os, json, logging
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from google.oauth2 import service_account
|
| 5 |
+
from googleapiclient.discovery import build
|
| 6 |
+
from googleapiclient.http import MediaFileUpload
|
| 7 |
+
|
| 8 |
+
from utils.token import get_credentials
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("dsaver")
|
| 11 |
+
if not logger.handlers:
|
| 12 |
+
logger.setLevel(logging.INFO)
|
| 13 |
+
fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
|
| 14 |
+
handler = logging.StreamHandler()
|
| 15 |
+
handler.setFormatter(fmt)
|
| 16 |
+
logger.addHandler(handler)
|
| 17 |
+
|
| 18 |
+
class DriveSaver:
|
| 19 |
+
"""Google Drive uploader. Prefers OAuth; optional SA fallback (Shared Drive only)."""
|
| 20 |
+
|
| 21 |
+
def __init__(self, default_folder_id: Optional[str] = None):
|
| 22 |
+
self.service = None
|
| 23 |
+
self.folder_id = default_folder_id or os.getenv("GDRIVE_FOLDER_ID")
|
| 24 |
+
self.supports_all_drives = os.getenv("GDRIVE_FOLDER_IS_SHARED", "false").lower() in ("1","true","yes")
|
| 25 |
+
self.allow_sa_fallback = os.getenv("GDRIVE_ALLOW_SA_FALLBACK", "false").lower() in ("1","true","yes")
|
| 26 |
+
if not self.folder_id:
|
| 27 |
+
logger.warning("📁 No GDRIVE_FOLDER_ID set; uploads must provide folder_id explicitly")
|
| 28 |
+
self._initialize_service()
|
| 29 |
+
|
| 30 |
+
def _initialize_service(self):
|
| 31 |
+
creds = get_credentials()
|
| 32 |
+
if creds:
|
| 33 |
+
logger.info("✅ Using OAuth credentials")
|
| 34 |
+
else:
|
| 35 |
+
# Optional SA fallback — ONLY valid for Shared Drives where SA is a member
|
| 36 |
+
if self.allow_sa_fallback:
|
| 37 |
+
creds_env = os.getenv("GDRIVE_CREDENTIALS_JSON")
|
| 38 |
+
if creds_env:
|
| 39 |
+
try:
|
| 40 |
+
info = json.loads(creds_env)
|
| 41 |
+
if info.get("type") == "service_account":
|
| 42 |
+
creds = service_account.Credentials.from_service_account_info(
|
| 43 |
+
info, scopes=["https://www.googleapis.com/auth/drive"]
|
| 44 |
+
)
|
| 45 |
+
logger.info("✅ Using Service Account credentials (fallback)")
|
| 46 |
+
if not self.supports_all_drives:
|
| 47 |
+
logger.warning("⚠️ SA fallback without Shared Drive mode will likely fail (no quota). "
|
| 48 |
+
"Set GDRIVE_FOLDER_IS_SHARED=true and use a Shared Drive folder ID.")
|
| 49 |
+
else:
|
| 50 |
+
logger.error("❌ GDRIVE_CREDENTIALS_JSON is not a service account JSON")
|
| 51 |
+
except Exception as e:
|
| 52 |
+
logger.error(f"❌ Failed to init Service Account: {e}")
|
| 53 |
+
if not creds:
|
| 54 |
+
logger.error("❌ No valid Google credentials available (OAuth or SA).")
|
| 55 |
+
self.service = None
|
| 56 |
+
return
|
| 57 |
+
# Build Drive service
|
| 58 |
+
self.service = build("drive", "v3", credentials=creds)
|
| 59 |
+
logger.info("✅ Google Drive service initialized")
|
| 60 |
+
|
| 61 |
+
def upload_file_to_drive(self, file_path: str, folder_id: Optional[str] = None, mimetype: Optional[str] = None) -> bool:
|
| 62 |
+
if not self.service:
|
| 63 |
+
logger.error("❌ Drive service not initialized")
|
| 64 |
+
return False
|
| 65 |
+
try:
|
| 66 |
+
target_folder = folder_id or self.folder_id
|
| 67 |
+
name = os.path.basename(file_path)
|
| 68 |
+
media = MediaFileUpload(file_path, mimetype=mimetype or "application/octet-stream")
|
| 69 |
+
metadata = {"name": name, "parents": [target_folder]}
|
| 70 |
+
req = self.service.files().create(
|
| 71 |
+
body=metadata,
|
| 72 |
+
media_body=media,
|
| 73 |
+
fields="id",
|
| 74 |
+
supportsAllDrives=self.supports_all_drives
|
| 75 |
+
)
|
| 76 |
+
req.execute()
|
| 77 |
+
logger.info(f"✅ Uploaded '{name}' to Drive (folder: {target_folder})")
|
| 78 |
+
return True
|
| 79 |
+
except Exception as e:
|
| 80 |
+
logger.error(f"❌ Drive upload failed: {e}")
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
def is_service_available(self) -> bool:
|
| 84 |
+
return self.service is not None
|
| 85 |
+
|
| 86 |
+
def set_folder_id(self, folder_id: str):
|
| 87 |
+
self.folder_id = folder_id
|
| 88 |
+
logger.info(f"📁 Default folder ID updated: {folder_id}")
|
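
A minimal upload sketch, assuming the OAuth secrets (GDRIVE_REFRESH_TOKEN and client credentials) and GDRIVE_FOLDER_ID are configured as described above; the file path is a placeholder.

# Sketch: upload a local artifact to the configured Drive folder
from utils.drive_saver import DriveSaver

saver = DriveSaver()
if saver.is_service_available():
    ok = saver.upload_file_to_drive("outputs/example.jsonl", mimetype="application/json")  # placeholder path
    print("uploaded:", ok)
else:
    print("Drive not configured; check OAuth/SA credentials")
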
utils/llm.py
ADDED
|
@@ -0,0 +1,186 @@
| 1 |
+
# Round-robin rotator + paraphrasing + translation/backtranslation
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import requests
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from google import genai
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger("llm")
|
| 9 |
+
if not logger.handlers:
|
| 10 |
+
logger.setLevel(logging.INFO)
|
| 11 |
+
handler = logging.StreamHandler()
|
| 12 |
+
logger.addHandler(handler)
|
| 13 |
+
|
| 14 |
+
# Trim LLM text to a short snippet for logging
|
| 15 |
+
def snip(s: str, n: int = 12) -> str:
|
| 16 |
+
if not isinstance(s, str): return "∅"
|
| 17 |
+
parts = s.strip().split()
|
| 18 |
+
return " ".join(parts[:n]) + (" …" if len(parts) > n else "")
|
| 19 |
+
|
| 20 |
+
class KeyRotator:
|
| 21 |
+
def __init__(self, env_prefix: str, max_keys: int = 5):
|
| 22 |
+
keys = []
|
| 23 |
+
for i in range(1, max_keys + 1):
|
| 24 |
+
v = os.getenv(f"{env_prefix}_{i}")
|
| 25 |
+
if v:
|
| 26 |
+
keys.append(v.strip())
|
| 27 |
+
if not keys:
|
| 28 |
+
logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
|
| 29 |
+
self.keys = keys
|
| 30 |
+
self.dead = set()
|
| 31 |
+
self.idx = 0
|
| 32 |
+
|
| 33 |
+
def next_key(self) -> Optional[str]:
|
| 34 |
+
if not self.keys:
|
| 35 |
+
return None
|
| 36 |
+
for _ in range(len(self.keys)):
|
| 37 |
+
k = self.keys[self.idx % len(self.keys)]
|
| 38 |
+
self.idx += 1
|
| 39 |
+
if k not in self.dead:
|
| 40 |
+
return k
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
def mark_bad(self, key: Optional[str]):
|
| 44 |
+
if key:
|
| 45 |
+
self.dead.add(key)
|
| 46 |
+
logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")
|
| 47 |
+
|
| 48 |
+
class GeminiClient:
|
| 49 |
+
def __init__(self, rotator: KeyRotator, default_model: str):
|
| 50 |
+
self.rotator = rotator
|
| 51 |
+
self.default_model = default_model
|
| 52 |
+
|
| 53 |
+
def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_output_tokens: int = 512) -> Optional[str]:
|
| 54 |
+
key = self.rotator.next_key()
|
| 55 |
+
if not key:
|
| 56 |
+
return None
|
| 57 |
+
try:
|
| 58 |
+
client = genai.Client(api_key=key)
|
| 59 |
+
# NOTE: uses the google-genai client.models.generate_content call pattern
|
| 60 |
+
res = client.models.generate_content(
|
| 61 |
+
model=model or self.default_model,
|
| 62 |
+
contents=prompt
|
| 63 |
+
)
|
| 64 |
+
text = getattr(res, "text", None)
|
| 65 |
+
if text:
|
| 66 |
+
logger.info(f"[LLM][Gemini] out={snip(text)}")
|
| 67 |
+
return text
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logger.error(f"[LLM][Gemini] {e}")
|
| 70 |
+
self.rotator.mark_bad(key)
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
class NvidiaClient:
|
| 74 |
+
def __init__(self, rotator: KeyRotator, default_model: str):
|
| 75 |
+
self.rotator = rotator
|
| 76 |
+
self.default_model = default_model
|
| 77 |
+
self.url = os.getenv("NVIDIA_API_URL", "https://integrate.api.nvidia.com/v1/chat/completions")
|
| 78 |
+
|
| 79 |
+
# Regex-based cleanup of boilerplate prefixes in the response
|
| 80 |
+
def _clean_resp(self, resp: str) -> str:
|
| 81 |
+
if not resp: return resp
|
| 82 |
+
txt = resp.strip()
|
| 83 |
+
# Remove common boilerplate prefixes
|
| 84 |
+
for pat in [
|
| 85 |
+
r"^Here is (a|the) .*?:\s*",
|
| 86 |
+
r"^Paraphrased(?: version)?:\s*",
|
| 87 |
+
r"^Sure[,.]?\s*",
|
| 88 |
+
r"^Okay[,.]?\s*"
|
| 89 |
+
]:
|
| 90 |
+
import re
|
| 91 |
+
txt = re.sub(pat, "", txt, flags=re.I)
|
| 92 |
+
return txt.strip()
|
| 93 |
+
|
| 94 |
+
def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_tokens: int = 512) -> Optional[str]:
|
| 95 |
+
key = self.rotator.next_key()
|
| 96 |
+
if not key:
|
| 97 |
+
return None
|
| 98 |
+
try:
|
| 99 |
+
headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
|
| 100 |
+
payload = {
|
| 101 |
+
"model": model or self.default_model,
|
| 102 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 103 |
+
"temperature": temperature,
|
| 104 |
+
"max_tokens": max_tokens
|
| 105 |
+
}
|
| 106 |
+
r = requests.post(self.url, headers=headers, json=payload, timeout=45)
|
| 107 |
+
if r.status_code >= 400:
|
| 108 |
+
raise RuntimeError(f"HTTP {r.status_code}: {r.text[:200]}")
|
| 109 |
+
data = r.json()
|
| 110 |
+
text = data["choices"][0]["message"]["content"]
|
| 111 |
+
clean = self._clean_resp(text)
|
| 112 |
+
logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
|
| 113 |
+
return clean
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logger.error(f"[LLM][NVIDIA] {e}")
|
| 116 |
+
self.rotator.mark_bad(key)
|
| 117 |
+
return None
|
| 118 |
+
|
| 119 |
+
class Paraphraser:
|
| 120 |
+
"""Prefers NVIDIA (cheap), falls back to Gemini. Also offers translate/backtranslate and a tiny consistency judge."""
|
| 121 |
+
def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
|
| 122 |
+
self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
|
| 123 |
+
self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
|
| 124 |
+
self.gm_hard = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_hard)
|
| 125 |
+
|
| 126 |
+
# Regex-based cleanup of boilerplate prefixes in the response
|
| 127 |
+
def _clean_resp(self, resp: str) -> str:
|
| 128 |
+
if not resp: return resp
|
| 129 |
+
txt = resp.strip()
|
| 130 |
+
# Remove common boilerplate prefixes
|
| 131 |
+
for pat in [
|
| 132 |
+
r"^Here is (a|the) .*?:\s*",
|
| 133 |
+
r"^Paraphrased(?: version)?:\s*",
|
| 134 |
+
r"^Sure[,.]?\s*",
|
| 135 |
+
r"^Okay[,.]?\s*"
|
| 136 |
+
]:
|
| 137 |
+
import re
|
| 138 |
+
txt = re.sub(pat, "", txt, flags=re.I)
|
| 139 |
+
return txt.strip()
|
| 140 |
+
|
| 141 |
+
# ————— Paraphrase —————
|
| 142 |
+
def paraphrase(self, text: str, difficulty: str = "easy") -> str:
|
| 143 |
+
if not text or len(text) < 12:
|
| 144 |
+
return text
|
| 145 |
+
prompt = (
|
| 146 |
+
"Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
|
| 147 |
+
"Do not fabricate or remove factual claims.\n"
|
| 148 |
+
"Return ONLY the rewritten text, without any introduction, commentary.\n"+ text
|
| 149 |
+
)
|
| 150 |
+
out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
|
| 151 |
+
if out: return self._clean_resp(out)
|
| 152 |
+
gm = self.gm_easy if difficulty == "easy" else self.gm_hard
|
| 153 |
+
out = gm.generate(prompt, max_output_tokens=min(600, max(128, len(text)//2)))
|
| 154 |
+
return self._clean_resp(out) if out else text
|
| 155 |
+
|
| 156 |
+
# ————— Translate & Backtranslate —————
|
| 157 |
+
def translate(self, text: str, target_lang: str = "de") -> Optional[str]:
|
| 158 |
+
if not text: return text
|
| 159 |
+
prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
|
| 160 |
+
out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
|
| 161 |
+
if out: return out.strip()
|
| 162 |
+
return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
|
| 163 |
+
|
| 164 |
+
def backtranslate(self, text: str, via_lang: str = "de") -> Optional[str]:
|
| 165 |
+
if not text: return text
|
| 166 |
+
mid = self.translate(text, target_lang=via_lang)
|
| 167 |
+
if not mid: return None
|
| 168 |
+
prompt = f"Translate the following {via_lang} text back to English, preserving the exact meaning:\n\n{mid}"
|
| 169 |
+
out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
|
| 170 |
+
if out: return out.strip()
|
| 171 |
+
res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
|
| 172 |
+
return res.strip() if res else None
|
| 173 |
+
|
| 174 |
+
# ————— Consistency Judge (cheap, ratio-based) —————
|
| 175 |
+
def consistency_check(self, user: str, output: str) -> bool:
|
| 176 |
+
"""Return True if 'output' appears supported by 'user' (context/question). Soft heuristic via LLM."""
|
| 177 |
+
prompt = (
|
| 178 |
+
"You are a strict medical QA validator. Given the USER input (question+context) "
|
| 179 |
+
"and the MODEL ANSWER, reply with exactly 'PASS' if the answer is supported and safe, "
|
| 180 |
+
"otherwise 'FAIL'. No extra text.\n\n"
|
| 181 |
+
f"USER:\n{user}\n\nANSWER:\n{output}"
|
| 182 |
+
)
|
| 183 |
+
out = self.nv.generate(prompt, temperature=0.0, max_tokens=3)
|
| 184 |
+
if not out:
|
| 185 |
+
out = self.gm_easy.generate(prompt, max_output_tokens=3)
|
| 186 |
+
return isinstance(out, str) and "PASS" in out.upper()
|
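
A small sketch of wiring up the Paraphraser; key rotation reads NVIDIA_API_1..5 and GEMINI_API_1..5 from the environment. The NVIDIA model name matches the app default, while the two Gemini model names here are placeholders (the real defaults live elsewhere in app.py).

# Sketch only; the Gemini model names below are assumptions
from utils.llm import Paraphraser

pp = Paraphraser(
    nvidia_model="meta/llama-3.1-8b-instruct",
    gemini_model_easy="gemini-1.5-flash",
    gemini_model_hard="gemini-1.5-pro",
)
text = "The patient should take the medication twice daily after meals."
print(pp.paraphrase(text, difficulty="easy"))                        # returns the input unchanged if no keys are set
print(pp.consistency_check("How often should the drug be taken?", text))  # True/False
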
utils/processor.py
ADDED
|
@@ -0,0 +1,411 @@
| 1 |
+
# Dataset-specific parsers + paraphrasing flow
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import hashlib
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Callable, Optional, Dict, Tuple
|
| 7 |
+
|
| 8 |
+
from utils.schema import sft_row
|
| 9 |
+
from utils import augment as A
|
| 10 |
+
|
| 11 |
+
# Logger
|
| 12 |
+
logger = logging.getLogger("processor")
|
| 13 |
+
if not logger.handlers:
|
| 14 |
+
logger.setLevel(logging.INFO)
|
| 15 |
+
logger.addHandler(logging.StreamHandler())
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _hash_id(*parts) -> str:
|
| 19 |
+
h = hashlib.sha256()
|
| 20 |
+
for p in parts:
|
| 21 |
+
h.update(str(p).encode("utf-8"))
|
| 22 |
+
return h.hexdigest()[:16]
|
| 23 |
+
|
| 24 |
+
def _iter_json_or_jsonl(path: str):
|
| 25 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 26 |
+
first = f.read(1); f.seek(0)
|
| 27 |
+
if first == "[":
|
| 28 |
+
data = json.load(f)
|
| 29 |
+
for obj in data: yield obj
|
| 30 |
+
else:
|
| 31 |
+
for line in f:
|
| 32 |
+
line = line.strip()
|
| 33 |
+
if line: yield json.loads(line)
|
| 34 |
+
|
| 35 |
+
def process_file_into_sft(
|
| 36 |
+
dataset_key: str,
|
| 37 |
+
input_path: str,
|
| 38 |
+
writer,
|
| 39 |
+
paraphraser,
|
| 40 |
+
augment_opts: Dict,
|
| 41 |
+
sample_limit: Optional[int],
|
| 42 |
+
seed: int,
|
| 43 |
+
progress_cb: Optional[Callable[[float, str], None]]
|
| 44 |
+
) -> Tuple[int, Dict]:
|
| 45 |
+
random.seed(seed)
|
| 46 |
+
stats = {
|
| 47 |
+
"written": 0,
|
| 48 |
+
"paraphrased_input": 0,
|
| 49 |
+
"paraphrased_output": 0,
|
| 50 |
+
"backtranslated_input": 0,
|
| 51 |
+
"backtranslated_output": 0,
|
| 52 |
+
"dedup_skipped": 0,
|
| 53 |
+
"consistency_failed": 0
|
| 54 |
+
}
|
| 55 |
+
# Start processing SFT
|
| 56 |
+
key_summary = {k: augment_opts.get(k) for k in (
|
| 57 |
+
"paraphrase_ratio","backtranslate_ratio","paraphrase_outputs",
|
| 58 |
+
"style_standardize","deidentify","dedupe",
|
| 59 |
+
"consistency_check_ratio","distill_fraction"
|
| 60 |
+
)}
|
| 61 |
+
logger.info(
|
| 62 |
+
f"[PROC] Begin dataset={dataset_key} sample_limit={sample_limit} opts={key_summary}"
|
| 63 |
+
)
|
| 64 |
+
# If deduplication is enabled
|
| 65 |
+
dedupe_seen = set() if augment_opts.get("dedupe", True) else None
|
| 66 |
+
|
| 67 |
+
key = dataset_key.lower()
|
| 68 |
+
if key in ("healthcaremagic", "icliniq"):
|
| 69 |
+
count = _proc_med_dialog(source=key, path=input_path, writer=writer,
|
| 70 |
+
paraphraser=paraphraser, opts=augment_opts,
|
| 71 |
+
sample_limit=sample_limit, stats=stats, cb=progress_cb, dedupe_seen=dedupe_seen)
|
| 72 |
+
elif key == "pubmedqa_l":
|
| 73 |
+
count = _proc_pubmedqa_l(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
|
| 74 |
+
elif key == "pubmedqa_u":
|
| 75 |
+
count = _proc_pubmedqa_u(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
|
| 76 |
+
elif key == "pubmedqa_map":
|
| 77 |
+
count = _proc_pubmedqa_map(input_path, writer, paraphraser, augment_opts, sample_limit, stats, progress_cb, dedupe_seen=dedupe_seen)
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError(f"Unknown dataset: {dataset_key}")
|
| 80 |
+
logger.info(f"[PROC] End dataset={dataset_key} stats={stats}")
|
| 81 |
+
return count, stats
|
| 82 |
+
|
| 83 |
+
# ——————————— helpers ———————————
|
| 84 |
+
def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
|
| 85 |
+
"""Return a list of (user_variant, out_variant, applied_tags) not including the original."""
|
| 86 |
+
variants = []
|
| 87 |
+
max_k = max(0, int(opts.get("max_aug_per_sample", 1)))
|
| 88 |
+
for _ in range(max_k):
|
| 89 |
+
applied = []
|
| 90 |
+
u2, did_p = A.maybe_paraphrase(user, opts.get("paraphrase_ratio", 0.0), paraphraser, "easy")
|
| 91 |
+
if did_p: applied.append("paraphrase_input"); stats["paraphrased_input"] += 1
|
| 92 |
+
u3, did_bt = A.maybe_backtranslate(u2, opts.get("backtranslate_ratio", 0.0), paraphraser)
|
| 93 |
+
if did_bt: applied.append("backtranslate_input"); stats["backtranslated_input"] += 1
|
| 94 |
+
|
| 95 |
+
o3 = out
|
| 96 |
+
if opts.get("paraphrase_outputs", False):
|
| 97 |
+
o2, did_p2 = A.maybe_paraphrase(out, opts.get("paraphrase_ratio", 0.0), paraphraser, "hard")
|
| 98 |
+
if did_p2: applied.append("paraphrase_output"); stats["paraphrased_output"] += 1
|
| 99 |
+
o3b, did_bt2 = A.maybe_backtranslate(o2, opts.get("backtranslate_ratio", 0.0), paraphraser)
|
| 100 |
+
if did_bt2: applied.append("backtranslate_output"); stats["backtranslated_output"] += 1
|
| 101 |
+
o3 = o3b
|
| 102 |
+
|
| 103 |
+
# If nothing applied, skip this variant
|
| 104 |
+
if not applied:
|
| 105 |
+
continue
|
| 106 |
+
# Apply style standardization and punctuation to the variant too
|
| 107 |
+
if opts.get("style_standardize", True):
|
| 108 |
+
o3 = A.style_standardize_answer(o3)
|
| 109 |
+
u3 = A.ensure_terminal_punct(u3) if u3 else u3
|
| 110 |
+
o3 = A.ensure_terminal_punct(o3) if o3 else o3
|
| 111 |
+
variants.append((u3, o3, applied))
|
| 112 |
+
return variants
|
| 113 |
+
|
| 114 |
+
def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
|
| 115 |
+
# Base cleanup & caps (returns cleaned strings)
|
| 116 |
+
user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
|
| 117 |
+
out = A.base_cleanup(out, opts.get("max_chars", 5000), opts.get("deidentify", True))
|
| 118 |
+
instr = A.base_cleanup(instr, opts.get("max_chars", 5000), False)
|
| 119 |
+
|
| 120 |
+
# Language sanity (mostly English—skip aggressive transforms if not)
|
| 121 |
+
if not A.lang_is_english(user): # very rare
|
| 122 |
+
return instr, user, out, []
|
| 123 |
+
|
| 124 |
+
# Track which augmentations and stylings have been applied
|
| 125 |
+
applied = []
|
| 126 |
+
|
| 127 |
+
# Standardize the answer style
|
| 128 |
+
if opts.get("style_standardize", True):
|
| 129 |
+
out = A.style_standardize_answer(out)
|
| 130 |
+
applied.append("style_standardize")
|
| 131 |
+
|
| 132 |
+
# Ensure punctuation/whitespace
|
| 133 |
+
user = A.ensure_terminal_punct(user) if user else user
|
| 134 |
+
out = A.ensure_terminal_punct(out) if out else out
|
| 135 |
+
|
| 136 |
+
return instr, user, out, applied
|
| 137 |
+
|
| 138 |
+
def _commit_row(writer, source, rid, task, instr, user, out, opts, stats, aug_applied, extra_meta=None, dedupe_seen=None):
|
| 139 |
+
# Dedup entry
|
| 140 |
+
if dedupe_seen is not None:
|
| 141 |
+
fp = A.fingerprint(instr, user, out)
|
| 142 |
+
if fp in dedupe_seen:
|
| 143 |
+
stats["dedup_skipped"] += 1
|
| 144 |
+
return False
|
| 145 |
+
dedupe_seen.add(fp)
|
| 146 |
+
|
| 147 |
+
meta = {"augmentations": aug_applied}
|
| 148 |
+
if extra_meta:
|
| 149 |
+
meta.update(extra_meta)
|
| 150 |
+
|
| 151 |
+
row = sft_row(instr, user, out, source=source, rid=rid, task=task, meta=meta)
|
| 152 |
+
writer.write(row)
|
| 153 |
+
stats["written"] += 1
|
| 154 |
+
return True
|
| 155 |
+
|
| 156 |
+
# ——————————— dataset processors ———————————
|
| 157 |
+
|
| 158 |
+
def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
|
| 159 |
+
count = 0
|
| 160 |
+
written = 0
|
| 161 |
+
for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
|
| 162 |
+
try:
|
| 163 |
+
instr_raw = obj.get("instruction") or "Answer the patient's question like a clinician. Be concise and safe."
|
| 164 |
+
user_raw = obj.get("input") or ""
|
| 165 |
+
out_raw = obj.get("output") or ""
|
| 166 |
+
|
| 167 |
+
# Ensure we have string values
|
| 168 |
+
instr = str(instr_raw).strip()
|
| 169 |
+
user = str(user_raw).strip()
|
| 170 |
+
out = str(out_raw).strip()
|
| 171 |
+
rid = _hash_id(source, i, len(user), len(out))
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logger.warning(f"[PROC] {source} error processing item {i}: {e}, item: {obj}")
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
instr, user, out, applied = _apply_aug(instr, user, out, source, opts, paraphraser, stats)
|
| 178 |
+
|
| 179 |
+
# 1) ALWAYS write the original (cleaned/style-standardised only)
|
| 180 |
+
# Optional consistency spot-check (cheap)
|
| 181 |
+
if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
|
| 182 |
+
stats["consistency_failed"] += 1
|
| 183 |
+
# keep the sample but tag it
|
| 184 |
+
applied.append("consistency_flag")
|
| 185 |
+
|
| 186 |
+
# 2) If expansion is enabled, add augmented copies
|
| 187 |
+
_commit_row(writer, source, rid, "medical_dialogue", instr, user, out, opts, stats, ["base"] + applied, dedupe_seen=dedupe_seen)
|
| 188 |
+
# Add augmented copies if expand
|
| 189 |
+
if opts.get("expand", True):
|
| 190 |
+
for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
|
| 191 |
+
rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
|
| 192 |
+
_commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
|
| 193 |
+
|
| 194 |
+
# Increment count only on success
|
| 195 |
+
count += 1
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.warning(f"[PROC] {source} error in processing/augmentation for item {i}: {e}")
|
| 198 |
+
continue
|
| 199 |
+
if sample_limit and count >= sample_limit:
|
| 200 |
+
break
|
| 201 |
+
if cb and i % 1000 == 0:
|
| 202 |
+
cb(min(0.9, 0.05 + i/200000), f"{source}: processed {i} rows")
|
| 203 |
+
if cb:
|
| 204 |
+
cb(0.92, f"{source} done ({count})")
|
| 205 |
+
logger.info(f"[PROC] {source} done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
|
| 206 |
+
return count
|
| 207 |
+
|
| 208 |
+
def _proc_pubmedqa_l(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
|
| 209 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 210 |
+
data = json.load(f)
|
| 211 |
+
count = 0
|
| 212 |
+
for k, v in data.items():
|
| 213 |
+
try:
|
| 214 |
+
q_raw = v.get("QUESTION") or ""
|
| 215 |
+
ctx_list = v.get("CONTEXTS") or []
|
| 216 |
+
long_ans_raw = v.get("LONG_ANSWER") or ""
|
| 217 |
+
final_raw = v.get("final_decision") or ""
|
| 218 |
+
|
| 219 |
+
# Ensure we have string values
|
| 220 |
+
q = str(q_raw).strip() if q_raw else ""
|
| 221 |
+
if isinstance(ctx_list, list):
|
| 222 |
+
context = "\n".join(str(ctx) for ctx in ctx_list).strip()
|
| 223 |
+
else:
|
| 224 |
+
context = str(ctx_list).strip()
|
| 225 |
+
long_ans = str(long_ans_raw).strip() if long_ans_raw else ""
|
| 226 |
+
final = str(final_raw).strip() if final_raw else ""
|
| 227 |
+
except Exception as e:
|
| 228 |
+
logger.warning(f"[PROC] pubmedqa_l error processing item {k}: {e}, item: {v}")
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
try:
|
| 232 |
+
instr = "Answer the biomedical question using the provided context. Include a concise rationale if possible."
|
| 233 |
+
user = f"Question: {q}\n\nContext:\n{context}" if context else f"Question: {q}"
|
| 234 |
+
out = long_ans if long_ans else final
|
| 235 |
+
rid = str(k)
|
| 236 |
+
|
| 237 |
+
instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_l", opts, paraphraser, stats)
|
| 238 |
+
_commit_row(writer, "pubmedqa_l", rid, "biomedical_qa", instr, user, out, opts, stats, applied,
|
| 239 |
+
extra_meta={"year": v.get("YEAR"), "meshes": v.get("MESHES"), "labels": v.get("LABELS")}, dedupe_seen=dedupe_seen)
|
| 240 |
+
if opts.get("expand", True):
|
| 241 |
+
for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
|
| 242 |
+
rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
|
| 243 |
+
_commit_row(writer, "pubmedqa_l", rid_aug, "biomedical_qa",
|
| 244 |
+
instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
|
| 245 |
+
|
| 246 |
+
# Increment count only on success
|
| 247 |
+
count += 1
|
| 248 |
+
except Exception as e:
|
| 249 |
+
logger.warning(f"[PROC] pubmedqa_l error in processing/augmentation for item {k}: {e}")
|
| 250 |
+
continue
|
| 251 |
+
if sample_limit and count >= sample_limit:
|
| 252 |
+
break
|
| 253 |
+
if cb and count % 1000 == 0:
|
| 254 |
+
cb(min(0.9, 0.05 + count/60000), f"pubmedqa_l processed {count}")
|
| 255 |
+
if cb:
|
| 256 |
+
cb(0.93, f"pubmedqa_l done ({count})")
|
| 257 |
+
logger.info(f"[PROC] pubmedqa_l done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
|
| 258 |
+
return count
|
| 259 |
+
|
| 260 |
+
def _proc_pubmedqa_u(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
|
| 261 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 262 |
+
data = json.load(f)
|
| 263 |
+
count = 0
|
| 264 |
+
for k, v in data.items():
|
| 265 |
+
try:
|
| 266 |
+
q_raw = v.get("QUESTION") or ""
|
| 267 |
+
ctx_list = v.get("CONTEXTS") or []
|
| 268 |
+
|
| 269 |
+
# Ensure we have string values
|
| 270 |
+
q = str(q_raw).strip() if q_raw else ""
|
| 271 |
+
if isinstance(ctx_list, list):
|
| 272 |
+
context = "\n".join(str(ctx) for ctx in ctx_list).strip()
|
| 273 |
+
else:
|
| 274 |
+
context = str(ctx_list).strip()
|
| 275 |
+
except Exception as e:
|
| 276 |
+
logger.warning(f"[PROC] pubmedqa_u error processing item {k}: {e}, item: {v}")
|
| 277 |
+
continue
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
instr = "Rewrite the context into a succinct note, then answer the question. If unknown, say 'insufficient evidence'."
|
| 281 |
+
user = f"Question: {q}\n\nContext:\n{context}" if context else f"Question: {q}"
|
| 282 |
+
out = "" # unlabeled
|
| 283 |
+
rid = str(k)
|
| 284 |
+
|
| 285 |
+
# Optional KD/distillation for a small fraction
|
| 286 |
+
if opts.get("distill_fraction", 0.0) > 0.0 and random.random() < float(opts["distill_fraction"]):
|
| 287 |
+
prompt = f"{instr}\n\n{user}\n\nAnswer briefly and safely."
|
| 288 |
+
guess = paraphraser.paraphrase(prompt, difficulty="hard") # cheap single call
|
| 289 |
+
if guess and len(guess) < 2000:
|
| 290 |
+
out = guess.strip()
|
| 291 |
+
|
| 292 |
+
instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_u", opts, paraphraser, stats)
|
| 293 |
+
_commit_row(writer, "pubmedqa_u", str(k), "biomedical_qa_unlabeled", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen)
|
| 294 |
+
if opts.get("expand", True):
|
| 295 |
+
for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
|
| 296 |
+
rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
|
| 297 |
+
_commit_row(writer, "pubmedqa_u", rid_aug, "biomedical_qa",
|
| 298 |
+
instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
|
| 299 |
+
|
| 300 |
+
# Increment count only on success
|
| 301 |
+
count += 1
|
| 302 |
+
except Exception as e:
|
| 303 |
+
logger.warning(f"[PROC] pubmedqa_u error in processing/augmentation for item {k}: {e}")
|
| 304 |
+
continue
|
| 305 |
+
if sample_limit and count >= sample_limit:
|
| 306 |
+
break
|
| 307 |
+
if cb and count % 2000 == 0:
|
| 308 |
+
cb(min(0.9, 0.05 + count/80000), f"pubmedqa_u processed {count}")
|
| 309 |
+
if cb:
|
| 310 |
+
cb(0.94, f"pubmedqa_u done ({count})")
|
| 311 |
+
logger.info(f"[PROC] pubmedqa_u done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
|
| 312 |
+
return count
|
| 313 |
+
|
| 314 |
+
def _proc_pubmedqa_map(path, writer, paraphraser, opts, sample_limit, stats, cb, dedupe_seen=None):
|
| 315 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 316 |
+
obj = json.load(f)
|
| 317 |
+
|
| 318 |
+
# Log the structure for debugging
|
| 319 |
+
logger.info(f"[PROC] pubmedqa_map data type: {type(obj)}")
|
| 320 |
+
if isinstance(obj, dict):
|
| 321 |
+
logger.info(f"[PROC] pubmedqa_map dict keys: {list(obj.keys())}")
|
| 322 |
+
if len(obj) > 0:
|
| 323 |
+
sample_key = next(iter(obj.keys()))
|
| 324 |
+
sample_value = obj[sample_key]
|
| 325 |
+
logger.info(f"[PROC] pubmedqa_map sample value type: {type(sample_value)}")
|
| 326 |
+
if isinstance(sample_value, dict):
|
| 327 |
+
logger.info(f"[PROC] pubmedqa_map sample value keys: {list(sample_value.keys())}")
|
| 328 |
+
|
| 329 |
+
# Iteration of items
|
| 330 |
+
def iter_items():
|
| 331 |
+
try:
|
| 332 |
+
if isinstance(obj, list):
|
| 333 |
+
for it in obj:
|
| 334 |
+
if isinstance(it, dict):
|
| 335 |
+
yield it
|
| 336 |
+
else:
|
| 337 |
+
logger.warning(f"[PROC] pubmedqa_map skipping non-dict list item: {type(it)}")
|
| 338 |
+
elif isinstance(obj, dict):
|
| 339 |
+
qs, cs, ans = obj.get("question"), obj.get("context"), obj.get("answer")
|
| 340 |
+
if isinstance(qs, list) and isinstance(cs, list) and isinstance(ans, list):
|
| 341 |
+
for i in range(min(len(qs), len(cs), len(ans))):
|
| 342 |
+
yield {"question": qs[i], "context": cs[i], "answer": ans[i]}
|
| 343 |
+
else:
|
| 344 |
+
# Handle case where values might be dictionaries or other objects
|
| 345 |
+
for k, v in obj.items():
|
| 346 |
+
if isinstance(v, dict):
|
| 347 |
+
# If v is a dict, ensure it has the expected structure
|
| 348 |
+
if "question" in v and "context" in v and "answer" in v:
|
| 349 |
+
yield v
|
| 350 |
+
else:
|
| 351 |
+
# Try to map the keys to expected structure
|
| 352 |
+
yield {
|
| 353 |
+
"question": v.get("question") or v.get("QUESTION") or str(k),
|
| 354 |
+
"context": v.get("context") or v.get("CONTEXT") or "",
|
| 355 |
+
"answer": v.get("answer") or v.get("ANSWER") or ""
|
| 356 |
+
}
|
| 357 |
+
else:
|
| 358 |
+
# If v is not a dict, create a simple structure
|
| 359 |
+
yield {"question": str(k), "context": str(v) if v else "", "answer": ""}
|
| 360 |
+
else:
|
| 361 |
+
logger.warning(f"[PROC] pubmedqa_map unexpected data type: {type(obj)}")
|
| 362 |
+
except Exception as e:
|
| 363 |
+
logger.error(f"[PROC] pubmedqa_map error in iter_items: {e}")
|
| 364 |
+
return
|
| 365 |
+
|
| 366 |
+
count = 0
|
| 367 |
+
for i, v in enumerate(iter_items(), start=1):
|
| 368 |
+
try:
|
| 369 |
+
# Ensure we have string values, convert if necessary
|
| 370 |
+
q_raw = v.get("question") or ""
|
| 371 |
+
c_raw = v.get("context") or ""
|
| 372 |
+
a_raw = v.get("answer") or ""
|
| 373 |
+
|
| 374 |
+
# Convert to string if not already
|
| 375 |
+
q = str(q_raw).strip() if q_raw else ""
|
| 376 |
+
c = str(c_raw).strip() if c_raw else ""
|
| 377 |
+
a = str(a_raw).strip() if a_raw else ""
|
| 378 |
+
|
| 379 |
+
instr = "Answer the biomedical question based on the context. Justify briefly."
|
| 380 |
+
user = f"Question: {q}\n\nContext:\n{c}" if c else f"Question: {q}"
|
| 381 |
+
out = a
|
| 382 |
+
rid = _hash_id("pubmedqa_map", i, len(q))
|
| 383 |
+
|
| 384 |
+
# Process the item
|
| 385 |
+
instr, user, out, applied = _apply_aug(instr, user, out, "pubmedqa_map", opts, paraphraser, stats)
|
| 386 |
+
_commit_row(writer, "pubmedqa_map", rid, "biomedical_qa", instr, user, out, opts, stats, applied, dedupe_seen=dedupe_seen)
|
| 387 |
+
|
| 388 |
+
# Handle expansion if enabled
|
| 389 |
+
if opts.get("expand", True):
|
| 390 |
+
for (u_aug, o_aug, aug_tags) in _build_variants(user, out, paraphraser, opts, stats):
|
| 391 |
+
rid_aug = f"{rid}-aug{random.randint(1000,9999)}"
|
| 392 |
+
_commit_row(writer, "pubmedqa_map", rid_aug, "biomedical_qa",
|
| 393 |
+
instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen)
|
| 394 |
+
|
| 395 |
+
# Increment count only on success
|
| 396 |
+
count += 1
|
| 397 |
+
|
| 398 |
+
except Exception as e:
|
| 399 |
+
logger.warning(f"[PROC] pubmedqa_map error processing item {i}: {e}, item: {v}")
|
| 400 |
+
continue
|
| 401 |
+
|
| 402 |
+
# Check sample limit
|
| 403 |
+
if sample_limit and count >= sample_limit:
|
| 404 |
+
break
|
| 405 |
+
if cb and i % 2000 == 0:
|
| 406 |
+
cb(min(0.9, 0.05 + i/120000), f"pubmedqa_map processed {i}")
|
| 407 |
+
|
| 408 |
+
if cb:
|
| 409 |
+
cb(0.95, f"pubmedqa_map done ({count})")
|
| 410 |
+
logger.info(f"[PROC] pubmedqa_map done count={count} written={stats['written']} dedup_skipped={stats['dedup_skipped']}")
|
| 411 |
+
return count
|
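
To exercise the SFT pipeline outside the FastAPI job runner, something like the following works as a sketch; the in-memory writer stands in for CentralisedWriter, and the input path, option values, and Gemini model names are placeholders.

# Sketch: run process_file_into_sft on a small local file with a stub writer
from utils.processor import process_file_into_sft
from utils.llm import Paraphraser

class ListWriter:
    """In-memory stand-in for CentralisedWriter (write/close only)."""
    def __init__(self):
        self.rows = []
    def write(self, row):
        self.rows.append(row)
    def close(self):
        pass

writer = ListWriter()
pp = Paraphraser("meta/llama-3.1-8b-instruct", "gemini-1.5-flash", "gemini-1.5-pro")  # Gemini names are placeholders
opts = {
    "paraphrase_ratio": 0.0, "backtranslate_ratio": 0.0, "paraphrase_outputs": False,
    "style_standardize": True, "deidentify": True, "dedupe": True,
    "consistency_check_ratio": 0.0, "distill_fraction": 0.0,
    "max_chars": 5000, "max_aug_per_sample": 1, "expand": False,
}
count, stats = process_file_into_sft(
    dataset_key="icliniq",
    input_path="cache/hf/iCliniq.json",   # placeholder path from a prior download
    writer=writer, paraphraser=pp, augment_opts=opts,
    sample_limit=10, seed=0, progress_cb=None,
)
print(count, stats["written"], len(writer.rows))
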
utils/rag.py
ADDED
|
@@ -0,0 +1,345 @@
| 1 |
+
# RAG-specific dataset processor
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import hashlib
|
| 5 |
+
import random
|
| 6 |
+
from typing import Dict, List, Tuple, Optional, Callable
|
| 7 |
+
|
| 8 |
+
from utils.schema import sft_row
|
| 9 |
+
from utils.llm import NvidiaClient, KeyRotator
|
| 10 |
+
|
| 11 |
+
# Logger
|
| 12 |
+
logger = logging.getLogger("rag_processor")
|
| 13 |
+
if not logger.handlers:
|
| 14 |
+
logger.setLevel(logging.INFO)
|
| 15 |
+
logger.addHandler(logging.StreamHandler())
|
| 16 |
+
|
| 17 |
+
def _hash_id(*parts) -> str:
|
| 18 |
+
"""Generate a hash ID for RAG entries"""
|
| 19 |
+
h = hashlib.sha256()
|
| 20 |
+
for p in parts:
|
| 21 |
+
h.update(str(p).encode("utf-8"))
|
| 22 |
+
return h.hexdigest()[:16]
|
| 23 |
+
|
| 24 |
+
def _iter_json_or_jsonl(path: str):
|
| 25 |
+
"""Iterate over JSON or JSONL files"""
|
| 26 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 27 |
+
first = f.read(1)
|
| 28 |
+
f.seek(0)
|
| 29 |
+
if first == "[":
|
| 30 |
+
data = json.load(f)
|
| 31 |
+
for obj in data:
|
| 32 |
+
yield obj
|
| 33 |
+
else:
|
| 34 |
+
for line in f:
|
| 35 |
+
line = line.strip()
|
| 36 |
+
if line:
|
| 37 |
+
yield json.loads(line)
|
| 38 |
+
|
| 39 |
+
class RAGProcessor:
|
| 40 |
+
"""Processes medical datasets into RAG-specific QCA (Question, Context, Answer) format"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, nvidia_model: str):
|
| 43 |
+
self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
|
| 44 |
+
|
| 45 |
+
def clean_conversational_content(self, text: str) -> str:
|
| 46 |
+
"""Remove conversational elements and non-medical information using NVIDIA model"""
|
| 47 |
+
if not text or len(text.strip()) < 10:
|
| 48 |
+
return text
|
| 49 |
+
|
| 50 |
+
prompt = f"""
|
| 51 |
+
You are a medical data cleaning expert. Clean the following text by:
|
| 52 |
+
1. Remove conversational elements (greetings, pleasantries)
|
| 53 |
+
2. Remove non-medical small talk and social interactions
|
| 54 |
+
3. Keep only medically relevant information
|
| 55 |
+
4. Preserve clinical facts, symptoms, diagnoses, treatments, and medical advice
|
| 56 |
+
5. Maintain professional medical language
|
| 57 |
+
6. Return only cleaned medical content, only plain text, no special characters, or formatting.
|
| 58 |
+
|
| 59 |
+
Text to clean:
|
| 60 |
+
{text}
|
| 61 |
+
|
| 62 |
+
Cleaned medical content:"""
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
cleaned = self.nvidia_client.generate(
|
| 66 |
+
prompt,
|
| 67 |
+
temperature=0.1,
|
| 68 |
+
max_tokens=min(1000, len(text) + 200)
|
| 69 |
+
)
|
| 70 |
+
return cleaned.strip() if cleaned else text
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.warning(f"[RAG] Error cleaning text: {e}")
|
| 73 |
+
return text
|
| 74 |
+
|
| 75 |
+
def generate_context_from_qa(self, question: str, answer: str) -> str:
|
| 76 |
+
"""Generate synthetic context from question and answer using NVIDIA model"""
|
| 77 |
+
if not question or not answer:
|
| 78 |
+
return ""
|
| 79 |
+
|
| 80 |
+
prompt = f"""You are a medical knowledge expert. Given a medical question and its answer, generate a brief relevant medical context that would help someone understand the answer better. Write about 2 sentences that provide relevant background information. Use only plain text without any formatting or symbols.
|
| 81 |
+
|
| 82 |
+
Question: {question}
|
| 83 |
+
|
| 84 |
+
Answer: {answer}
|
| 85 |
+
|
| 86 |
+
Generate a concise medical context:"""
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
context = self.nvidia_client.generate(
|
| 90 |
+
prompt,
|
| 91 |
+
temperature=0.2,
|
| 92 |
+
max_tokens=200
|
| 93 |
+
)
|
| 94 |
+
return context.strip() if context else ""
|
| 95 |
+
except Exception as e:
|
| 96 |
+
logger.warning(f"[RAG] Error generating context: {e}")
|
| 97 |
+
return ""
|
| 98 |
+
|
| 99 |
+
    def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
        """Convert SFT format to QCA (Question, Context, Answer) format"""
        # Clean the content to remove conversational elements
        cleaned_input = self.clean_conversational_content(user_input)
        cleaned_output = self.clean_conversational_content(output)

        # Extract question from user input
        question = self.extract_question(cleaned_input)

        # Extract or generate context
        context = self.extract_context(cleaned_input, question, cleaned_output)

        # The cleaned output serves directly as the answer
        answer = cleaned_output

        return question, context, answer
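    # Illustrative sketch of the transformation (all values hypothetical): an SFT
    # record whose input mixes pleasantries with a clinical question, e.g.
    #   input  = "Hi doctor, I have had a fever for 3 days, what should I do?"
    #   output = "Thanks for writing in. Take paracetamol, rest and hydrate..."
    # is reduced to roughly
    #   question = "I have had a fever for 3 days, what should I do?"
    #   context  = <short background generated by generate_context_from_qa>
    #   answer   = "Take paracetamol, rest and hydrate..."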
    def extract_question(self, user_input: str) -> str:
        """Extract the main question from user input"""
        if not user_input:
            return ""

        # Try to identify question patterns
        lines = user_input.split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith('Question:') or line.startswith('Q:'):
                return line.replace('Question:', '').replace('Q:', '').strip()
            elif '?' in line and len(line) > 10:
                return line

        # If no clear question found, use the first meaningful line
        for line in lines:
            line = line.strip()
            if len(line) > 10:
                return line

        return user_input
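    # Heuristic sketch (hypothetical input): for
    #   "Patient history attached.\nQ: Is metformin safe with stage 3 CKD?"
    # the first pass matches the 'Q:' prefix and returns
    #   "Is metformin safe with stage 3 CKD?"
    # while free-text inputs fall back to the first line longer than 10 characters.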
    def extract_context(self, user_input: str, question: str, answer: str) -> str:
        """Extract context from user input or generate synthetic context"""
        # Look for context in the original input
        context_candidates = []
        lines = user_input.split('\n')

        for line in lines:
            line = line.strip()
            if (line.startswith('Context:') or
                    line.startswith('Background:') or
                    line.startswith('Information:') or
                    (len(line) > 50 and not line.startswith('Question:') and '?' not in line)):
                context_candidates.append(line)

        if context_candidates:
            # Clean and combine context candidates
            context = ' '.join(context_candidates)
            context = self.clean_conversational_content(context)
            if len(context) > 20:  # Ensure we have meaningful context
                return context

        # Generate synthetic context if none found
        if question and answer:
            synthetic_context = self.generate_context_from_qa(question, answer)
            if synthetic_context:
                return synthetic_context

        return ""
    def process_medical_dialog(self, source: str, path: str, writer, sample_limit: Optional[int],
                               stats: Dict, progress_cb: Optional[Callable], dedupe_seen: Optional[set] = None) -> int:
        """Process medical dialogue datasets into RAG format"""
        count = 0
        written = 0

        for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
            try:
                instr_raw = obj.get("instruction") or "Answer the medical question based on the provided context."
                user_raw = obj.get("input") or ""
                out_raw = obj.get("output") or ""

                instr = str(instr_raw).strip()
                user = str(user_raw).strip()
                out = str(out_raw).strip()
                rid = _hash_id(source, i, len(user), len(out))

                # Convert to QCA format
                question, context, answer = self.convert_to_qca_format(instr, user, out)

                if not question or not answer:
                    continue

                # Create RAG-specific instruction
                rag_instruction = "Answer the medical question based on the provided context. If the context is insufficient, provide the best available medical information."

                # Format user input as QCA
                if context:
                    rag_user = f"Question: {question}\n\nContext: {context}"
                else:
                    rag_user = f"Question: {question}"

                # Commit the RAG-formatted row
                if self._commit_rag_row(writer, source, rid, "rag_medical_qa",
                                        rag_instruction, rag_user, answer,
                                        stats, dedupe_seen=dedupe_seen):
                    written += 1

                count += 1

            except Exception as e:
                logger.warning(f"[RAG] {source} error processing item {i}: {e}")
                continue

            if sample_limit and count >= sample_limit:
                break
            if progress_cb and i % 1000 == 0:
                progress_cb(min(0.9, 0.05 + i / 200000), f"{source}: processed {i} rows for RAG")

        if progress_cb:
            progress_cb(0.92, f"{source} RAG processing done ({count})")

        logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
        return count
    def process_pubmedqa(self, source: str, path: str, writer, sample_limit: Optional[int],
                         stats: Dict, progress_cb: Optional[Callable], dedupe_seen: Optional[set] = None) -> int:
        """Process PubMedQA datasets into RAG format"""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        count = 0
        written = 0

        for k, v in data.items():
            try:
                q_raw = v.get("QUESTION") or ""
                ctx_list = v.get("CONTEXTS") or []
                long_ans_raw = v.get("LONG_ANSWER") or ""
                final_raw = v.get("final_decision") or ""

                question = str(q_raw).strip() if q_raw else ""
                if isinstance(ctx_list, list):
                    context = "\n".join(str(ctx) for ctx in ctx_list).strip()
                else:
                    context = str(ctx_list).strip()
                answer = str(long_ans_raw).strip() if long_ans_raw else str(final_raw).strip()

                if not question or not answer:
                    continue

                # Clean the content
                question = self.clean_conversational_content(question)
                context = self.clean_conversational_content(context)
                answer = self.clean_conversational_content(answer)

                # Generate context if missing
                if not context:
                    context = self.generate_context_from_qa(question, answer)

                rid = str(k)
                rag_instruction = "Answer the biomedical question based on the provided context."

                if context:
                    rag_user = f"Question: {question}\n\nContext: {context}"
                else:
                    rag_user = f"Question: {question}"

                # Commit the RAG-formatted row
                if self._commit_rag_row(writer, source, rid, "rag_biomedical_qa",
                                        rag_instruction, rag_user, answer,
                                        stats, dedupe_seen=dedupe_seen):
                    written += 1

                count += 1

            except Exception as e:
                logger.warning(f"[RAG] {source} error processing item {k}: {e}")
                continue

            if sample_limit and count >= sample_limit:
                break
            if progress_cb and count % 1000 == 0:
                progress_cb(min(0.9, 0.05 + count / 60000), f"{source} RAG processed {count}")

        if progress_cb:
            progress_cb(0.93, f"{source} RAG processing done ({count})")

        logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
        return count
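    # Input shape sketch: process_pubmedqa expects the PubMedQA JSON mapping of
    # id -> record and reads only the keys used above, roughly
    #   {"12345678": {"QUESTION": "...", "CONTEXTS": ["...", "..."],
    #                 "LONG_ANSWER": "...", "final_decision": "yes"}}
    # (the id and values here are placeholders).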
    def _commit_rag_row(self, writer, source: str, rid: str, task: str,
                        instruction: str, user_input: str, output: str,
                        stats: Dict, dedupe_seen: Optional[set] = None) -> bool:
        """Commit a RAG-formatted row to the writer"""
        # Simple deduplication based on a content hash
        if dedupe_seen is not None:
            content_hash = hashlib.md5(f"{user_input}{output}".encode()).hexdigest()
            if content_hash in dedupe_seen:
                stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
                return False
            dedupe_seen.add(content_hash)

        meta = {"rag_processing": True, "format": "qca"}
        row = sft_row(instruction, user_input, output, source=source, rid=rid, task=task, meta=meta)
        writer.write(row)
        stats["written"] = stats.get("written", 0) + 1
        return True
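    # Deduplication note: the key is md5(user_input + output), so two records with
    # the same question/context text and the same answer collapse to one written
    # row; records that differ only in their instruction are still deduplicated,
    # because the instruction is not part of the hash.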
def process_file_into_rag(
    dataset_key: str,
    input_path: str,
    writer,
    nvidia_model: str,
    sample_limit: Optional[int],
    seed: int,
    progress_cb: Optional[Callable[[float, str], None]]
) -> Tuple[int, Dict]:
    """Main entry point for RAG processing"""
    random.seed(seed)
    stats = {
        "written": 0,
        "dedup_skipped": 0
    }

    logger.info(f"[RAG] Begin RAG processing dataset={dataset_key} sample_limit={sample_limit}")

    # Initialize RAG processor
    rag_processor = RAGProcessor(nvidia_model)
    dedupe_seen = set()

    key = dataset_key.lower()
    if key in ("healthcaremagic", "icliniq"):
        count = rag_processor.process_medical_dialog(
            source=key, path=input_path, writer=writer,
            sample_limit=sample_limit, stats=stats,
            progress_cb=progress_cb, dedupe_seen=dedupe_seen
        )
    elif key in ("pubmedqa_l", "pubmedqa_u", "pubmedqa_map"):
        count = rag_processor.process_pubmedqa(
            source=key, path=input_path, writer=writer,
            sample_limit=sample_limit, stats=stats,
            progress_cb=progress_cb, dedupe_seen=dedupe_seen
        )
    else:
        raise ValueError(f"Unknown dataset for RAG processing: {dataset_key}")

    logger.info(f"[RAG] End RAG processing dataset={dataset_key} stats={stats}")
    return count, stats
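# Usage sketch (output paths, model name and callback are hypothetical; the
# actual values come from the Gradio app):
#
#   writer = CentralisedWriter("cache/out/rag.jsonl", "cache/out/rag.csv")
#   count, stats = process_file_into_rag(
#       dataset_key="healthcaremagic",
#       input_path="cache/datasets/healthcaremagic.json",
#       writer=writer,
#       nvidia_model="<nvidia-model-id>",
#       sample_limit=1000,
#       seed=42,
#       progress_cb=lambda frac, msg: print(f"{frac:.0%} {msg}"),
#   )
#   writer.close()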
utils/schema.py
ADDED
@@ -0,0 +1,68 @@
# Centralized SFT writer (JSONL + CSV)
import csv
import orjson
from typing import Optional, Dict
import logging

# Logger
logger = logging.getLogger("schema")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
def sft_row(instruction: str, user_input: str, output: str, source: str, rid: str, task: str, meta: Optional[dict] = None):
    return {
        "source": source,
        "id": rid,
        "task": task,
        "sft": {
            "instruction": instruction,
            "input": user_input,
            "output": output
        },
        "meta": meta or {}
    }
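# Shape sketch (placeholder values): every processor funnels its records through
# sft_row, so a written row always looks like
#   {"source": "icliniq", "id": "ab12cd34ef56ab12", "task": "rag_medical_qa",
#    "sft": {"instruction": "...", "input": "Question: ...\n\nContext: ...",
#            "output": "..."},
#    "meta": {"rag_processing": True, "format": "qca"}}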
def is_valid_row(row: Dict, max_chars: int = 20000) -> bool:
    s = row.get("sft", {})
    instr = s.get("instruction", "")
    inp = s.get("input", "")
    out = s.get("output", "")
    # Basic sanity: require a non-empty input OR output; cap extreme lengths
    if not (inp or out):
        return False
    if any(len(x) > max_chars for x in (instr, inp, out)):
        return False
    return True
class CentralisedWriter:
    """Streams JSONL + CSV in parallel to stay memory-safe."""

    def __init__(self, jsonl_path: str, csv_path: str):
        self.jsonl_fp = open(jsonl_path, "wb")
        self.csv_fp = open(csv_path, "w", newline="", encoding="utf-8")
        self.csv_wr = csv.DictWriter(self.csv_fp, fieldnames=["instruction", "input", "output", "source", "id", "task"])
        self.csv_wr.writeheader()

    def write(self, row: dict):
        if not is_valid_row(row):
            s = row.get("sft", {})
            logger.warning(
                f"[WRITER] Skipping invalid row id={row.get('id')} "
                f"(len instr={len(s.get('instruction',''))}, input={len(s.get('input',''))}, output={len(s.get('output',''))})"
            )
            return
        self.jsonl_fp.write(orjson.dumps(row))
        self.jsonl_fp.write(b"\n")
        s = row["sft"]
        self.csv_wr.writerow({
            "instruction": s.get("instruction", ""),
            "input": s.get("input", ""),
            "output": s.get("output", ""),
            "source": row.get("source", ""),
            "id": row.get("id", ""),
            "task": row.get("task", "")
        })

    def close(self):
        try:
            self.jsonl_fp.close()
        finally:
            self.csv_fp.close()
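# Usage sketch (output paths are hypothetical): the writer is opened once per run,
# fed rows built with sft_row, and closed in a finally block so both files are
# flushed even if a processor raises.
#
#   writer = CentralisedWriter("cache/out/sft.jsonl", "cache/out/sft.csv")
#   try:
#       writer.write(sft_row("Answer the question.", "Question: ...", "...",
#                            source="icliniq", rid="0001", task="rag_medical_qa"))
#   finally:
#       writer.close()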
utils/token.py
ADDED
@@ -0,0 +1,107 @@
# Google Drive OAuth credential/token refresher
import os, json, logging
from typing import Optional
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from google.auth.transport.requests import Request

logger = logging.getLogger("token")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    logger.addHandler(handler)

SCOPES = ["https://www.googleapis.com/auth/drive.file"]
TOKEN_FILE = os.getenv("GDRIVE_TOKEN_FILE", "cache/secrets/gdrive_token.json")
def _load_oauth_client_web():
    cfg_env = os.getenv("GDRIVE_CREDENTIALS_JSON")
    if not cfg_env:
        return None
    try:
        cfg = json.loads(cfg_env)
        return cfg.get("web")
    except Exception as e:
        logger.error(f"❌ Failed to parse GDRIVE_CREDENTIALS_JSON: {e}")
        return None

def _ensure_dirs():
    base = os.path.dirname(TOKEN_FILE)
    if base and not os.path.exists(base):
        os.makedirs(base, exist_ok=True)
def get_credentials() -> Optional[Credentials]:
    # 1) Token file
    if os.path.exists(TOKEN_FILE):
        try:
            with open(TOKEN_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
            creds = Credentials.from_authorized_user_info(data, scopes=SCOPES)
            if creds and creds.expired and creds.refresh_token:
                creds.refresh(Request())
                logger.info("🔄 Refreshed access token from token file")
            return creds
        except Exception as e:
            logger.warning(f"⚠️ Failed to load token file: {e}")

    # 2) Refresh token in env
    refresh = os.getenv("GDRIVE_REFRESH_TOKEN")
    web = _load_oauth_client_web()
    if refresh and web:
        creds = Credentials(
            None,
            refresh_token=refresh,
            token_uri="https://oauth2.googleapis.com/token",
            client_id=web.get("client_id"),
            client_secret=web.get("client_secret"),
            scopes=SCOPES,
        )
        if creds and (creds.expired or not creds.valid):
            try:
                creds.refresh(Request())
                logger.info("🔄 Refreshed access token from env refresh token")
            except Exception as e:
                logger.warning(f"⚠️ Refresh with env token failed: {e}")
        return creds

    # 3) Nothing available
    return None
def build_auth_url(redirect_uri: str) -> str:
    web = _load_oauth_client_web()
    if not web:
        raise RuntimeError("GDRIVE_CREDENTIALS_JSON missing or invalid ('web' section required)")
    flow = Flow.from_client_config({"web": web}, scopes=SCOPES, redirect_uri=redirect_uri)
    auth_url, _ = flow.authorization_url(
        prompt="consent",
        access_type="offline",
        include_granted_scopes="true"
    )
    return auth_url
def exchange_code(code: str, redirect_uri: str) -> Credentials:
    web = _load_oauth_client_web()
    if not web:
        raise RuntimeError("GDRIVE_CREDENTIALS_JSON missing or invalid ('web' section required)")
    flow = Flow.from_client_config({"web": web}, scopes=SCOPES, redirect_uri=redirect_uri)
    flow.fetch_token(code=code)
    creds: Credentials = flow.credentials

    info = {
        "token": creds.token,
        "refresh_token": creds.refresh_token,
        "token_uri": "https://oauth2.googleapis.com/token",
        "client_id": web.get("client_id"),
        "client_secret": web.get("client_secret"),
        "scopes": SCOPES,
    }
    _ensure_dirs()
    with open(TOKEN_FILE, "w", encoding="utf-8") as f:
        json.dump(info, f)
    logger.info("✅ Saved Google refresh token to %s", TOKEN_FILE)

    # Also set env for the current process
    if creds.refresh_token:
        os.environ["GDRIVE_REFRESH_TOKEN"] = creds.refresh_token

    return creds
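# Flow sketch (the redirect URI is a placeholder and must match the OAuth client
# configuration): the app first sends the user to build_auth_url(), Google
# redirects back with ?code=..., exchange_code() persists the refresh token, and
# later runs call get_credentials() to obtain refreshed credentials silently.
#
#   url = build_auth_url("https://<space-url>/oauth2callback")
#   # ...user authorizes, Google redirects back with a code...
#   creds = exchange_code(code, "https://<space-url>/oauth2callback")
#   creds = get_credentials()  # on subsequent runs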