File size: 8,645 Bytes
c293f7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#!/usr/bin/env python3
"""
Training Script for Fine-tuning IndicBERT on Misinformation Detection
Combines the data loader and enhanced IndicBERT processor for end-to-end training.
"""

import sys
import os
import logging
from pathlib import Path

# Add backend to path
sys.path.insert(0, str(Path(__file__).parent))

from data_loader import DataLoader
from enhanced_indicbert_processor import EnhancedIndicBERTProcessor
from advanced_ml_classifier import create_comprehensive_training_data
import pandas as pd

# Module logger for this script; the root configuration below makes every
# INFO-level pipeline message visible on the console.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)


# Integer label -> class name as reported by the processor's predictions
# (1 = fake, 0 = real).
_LABEL_NAMES = {0: "real", 1: "fake"}


def _load_training_dataframe(data_loader, use_existing_dataset, dataset_filename):
    """Return the raw training DataFrame, loading it from disk or building it.

    When ``use_existing_dataset`` is true and ``datasets/<dataset_filename>``
    exists, that file is read with ``data_loader.load_csv``. Otherwise the
    synthetic examples from ``create_comprehensive_training_data()`` are
    concatenated with ``data_loader.create_indian_context_dataset()``, and the
    combined frame is saved as ``combined_training_data.csv`` before being
    returned.
    """
    dataset_path = Path("datasets") / dataset_filename
    if use_existing_dataset:
        if dataset_path.exists():
            logger.info(f"Loading existing dataset: {dataset_filename}")
            return data_loader.load_csv(dataset_filename)
        # The caller asked for an existing file; make the fallback visible
        # instead of silently training on a freshly generated dataset.
        logger.warning(
            f"Dataset {dataset_path} not found; creating a new dataset instead"
        )

    logger.info("Creating new Indian context dataset...")
    # Combine synthetic dataset from advanced_ml_classifier with
    # Indian-specific examples.
    synthetic_df = create_comprehensive_training_data()
    indian_df = data_loader.create_indian_context_dataset()
    df = pd.concat([synthetic_df, indian_df], ignore_index=True)

    data_loader.save_processed_dataset(df, "combined_training_data.csv")
    logger.info(f"Created and saved combined dataset with {len(df)} examples")
    return df


def _evaluate_on_test_set(processor, test_df, num_samples=5):
    """Log test-set accuracy plus a few sample predictions; return the accuracy.

    Expects a non-empty ``test_df`` with ``text`` and ``label`` columns, and a
    ``processor.predict_batch`` that returns one dict per text carrying
    ``prediction`` ('fake' / 'real') and ``confidence`` entries.
    """
    texts = test_df['text'].tolist()
    labels = test_df['label'].tolist()
    predictions = processor.predict_batch(texts)

    # Only labels 0/1 have a class name; any other label never counts as correct.
    correct = 0
    for pred, label in zip(predictions, labels):
        expected = _LABEL_NAMES.get(label)
        if expected is not None and pred['prediction'] == expected:
            correct += 1
    accuracy = correct / len(labels)

    logger.info(f"  Test Set Accuracy: {accuracy:.4f}")

    logger.info("\n🔍 Sample Predictions:")
    for text, label, pred in zip(texts[:num_samples], labels, predictions):
        true_label = 'fake' if label == 1 else 'real'
        # Only add an ellipsis when the text was actually truncated.
        snippet = text if len(text) <= 80 else f"{text[:80]}..."
        logger.info(f"\n  Text: {snippet}")
        logger.info(
            f"  True: {true_label}, Predicted: {pred['prediction']}, "
            f"Confidence: {pred['confidence']:.4f}"
        )
    return accuracy


def train_indicbert_model(
    use_existing_dataset: bool = False,
    dataset_filename: str = "indian_misinformation.csv",
    epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    test_size: float = 0.2,
    val_size: float = 0.1,
    output_dir: str = "models/finetuned_indicbert",
    apply_quantization: bool = False
):
    """
    Complete training pipeline for IndicBERT fine-tuning.

    Loads (or builds) the dataset, preprocesses it with class balancing,
    splits it into train/validation/test, fine-tunes ``ai4bharat/indic-bert``
    as a 2-label classifier, evaluates on the test split when it is non-empty,
    and optionally quantizes and saves a second copy of the model.

    Args:
        use_existing_dataset: Whether to use existing dataset file or create new one
        dataset_filename: Name of dataset file in datasets/ directory
        epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate for optimizer
        test_size: Fraction of data for test set
        val_size: Fraction of data for validation set
        output_dir: Directory to save fine-tuned model
        apply_quantization: Whether to apply quantization after training

    Returns:
        ``(processor, training_results)``: the fine-tuned
        ``EnhancedIndicBERTProcessor`` and the results mapping returned by its
        ``fine_tune`` call.
    """
    logger.info("🚀 Starting IndicBERT Fine-tuning Pipeline")
    logger.info("=" * 70)

    # Step 1: Load Data
    logger.info("\n📊 Step 1: Loading Data")
    data_loader = DataLoader(datasets_dir='datasets')
    df = _load_training_dataframe(data_loader, use_existing_dataset, dataset_filename)

    # Step 2: Preprocess Data
    logger.info("\n🔧 Step 2: Preprocessing Data")
    df = data_loader.preprocess(
        df,
        clean_text=True,
        remove_duplicates=True,
        balance_classes=True  # Balance for better training
    )

    # Step 3: Split Data
    logger.info("\n✂️ Step 3: Splitting Data")
    train_df, val_df, test_df = data_loader.split_data(
        df,
        test_size=test_size,
        val_size=val_size
    )

    # Step 4: Initialize IndicBERT Processor
    logger.info("\n🧠 Step 4: Initializing IndicBERT Processor")
    processor = EnhancedIndicBERTProcessor(
        model_name="ai4bharat/indic-bert",
        num_labels=2
    )

    # Step 5: Fine-tune Model (validation data only when the split produced any)
    logger.info("\n🎓 Step 5: Fine-tuning IndicBERT")
    has_val = len(val_df) > 0
    training_results = processor.fine_tune(
        train_texts=train_df['text'].tolist(),
        train_labels=train_df['label'].tolist(),
        val_texts=val_df['text'].tolist() if has_val else None,
        val_labels=val_df['label'].tolist() if has_val else None,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        output_dir=output_dir
    )

    logger.info("\n📊 Training Results:")
    logger.info(f"  Final Training Accuracy: {training_results['final_train_accuracy']:.4f}")
    best_val_accuracy = training_results.get('best_val_accuracy')
    # Compare with None: an accuracy of 0.0 is falsy but still worth reporting.
    if best_val_accuracy is not None:
        logger.info(f"  Best Validation Accuracy: {best_val_accuracy:.4f}")

    # Step 6: Evaluate on Test Set
    if len(test_df) > 0:
        logger.info("\n📈 Step 6: Evaluating on Test Set")
        _evaluate_on_test_set(processor, test_df)

    # Step 7: Apply Quantization (Optional)
    if apply_quantization:
        logger.info("\n⚡ Step 7: Applying Quantization")
        processor.quantize_model()
        # Strip trailing separators so "dir/" becomes "dir_quantized",
        # not "dir/_quantized".
        quantized_output_dir = output_dir.rstrip("/" + os.sep) + "_quantized"
        processor.save_model(quantized_output_dir)
        logger.info(f"Quantized model saved to {quantized_output_dir}")

    logger.info("\n" + "=" * 70)
    logger.info("✅ Training Pipeline Completed Successfully!")
    logger.info(f"📁 Model saved to: {output_dir}")

    return processor, training_results


# Headlines classified by test_trained_model when the caller supplies none.
_DEFAULT_TEST_SAMPLES = (
    "BREAKING: Modi government secretly selling India to China, conspiracy exposed!",
    "Government of India announces new infrastructure development plan",
    "EXPOSED: Vaccines contain microchips to control population",
    "According to Ministry of Health, COVID-19 cases declining nationwide",
    "URGENT: This miracle cure will shock you - doctors hate it!",
    "Supreme Court delivers verdict on constitutional matter",
)


def test_trained_model(model_dir: str = "models/finetuned_indicbert", samples=None):
    """Test the trained model with sample inputs.

    Loads the model from ``model_dir`` into a default-constructed
    ``EnhancedIndicBERTProcessor`` and logs the predicted label, confidence
    and per-class probabilities for each sample.

    Args:
        model_dir: Directory passed to ``processor.load_model``.
        samples: Optional iterable of texts to classify. Defaults to a fixed
            set of six headlines mixing misinformation-style and factual text.

    Returns:
        The list returned by ``processor.predict_batch`` (``[]`` when
        ``samples`` is empty).
    """
    test_samples = list(_DEFAULT_TEST_SAMPLES if samples is None else samples)

    logger.info("🧪 Testing Trained Model")
    logger.info("=" * 70)

    if not test_samples:
        # Nothing to classify; skip loading the model entirely.
        logger.warning("No samples provided; nothing to test")
        return []

    # Load model
    processor = EnhancedIndicBERTProcessor()
    processor.load_model(model_dir)

    logger.info("\n📝 Test Predictions:\n")

    predictions = processor.predict_batch(test_samples)

    for text, pred in zip(test_samples, predictions):
        logger.info(f"Text: {text}")
        logger.info(f"Prediction: {pred['prediction'].upper()} (confidence: {pred['confidence']:.4f})")
        logger.info(f"Probabilities: Real={pred['probabilities']['real']:.4f}, Fake={pred['probabilities']['fake']:.4f}")
        logger.info("-" * 70)

    logger.info("\n✅ Testing Complete!")
    return predictions


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train IndicBERT for misinformation detection")
    parser.add_argument("--mode", type=str, default="train", choices=["train", "test"],
                        help="Mode: train or test")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--test_size", type=float, default=0.2,
                        help="Fraction of data for the test set")
    parser.add_argument("--val_size", type=float, default=0.1,
                        help="Fraction of data for the validation set")
    parser.add_argument("--output_dir", type=str, default="models/finetuned_indicbert",
                        help="Output directory for model")
    parser.add_argument("--quantize", action="store_true", help="Apply quantization after training")
    parser.add_argument("--use_existing", action="store_true",
                        help="Use existing dataset file")
    parser.add_argument("--dataset", type=str, default="indian_misinformation.csv",
                        help="Dataset filename")

    args = parser.parse_args()

    # Reject obviously invalid values up front rather than failing mid-training.
    if args.epochs < 1:
        parser.error(f"--epochs must be at least 1, got {args.epochs}")
    if args.batch_size < 1:
        parser.error(f"--batch_size must be at least 1, got {args.batch_size}")
    for option in ("test_size", "val_size"):
        fraction = getattr(args, option)
        if not 0.0 <= fraction < 1.0:
            parser.error(f"--{option} must be in [0, 1), got {fraction}")

    if args.mode == "train":
        train_indicbert_model(
            use_existing_dataset=args.use_existing,
            dataset_filename=args.dataset,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            test_size=args.test_size,
            val_size=args.val_size,
            output_dir=args.output_dir,
            apply_quantization=args.quantize
        )
    else:
        # In test mode, --output_dir names the directory of the model to load.
        test_trained_model(model_dir=args.output_dir)