File size: 11,719 Bytes
31f0e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#!/usr/bin/env python
"""

Dataset Validation Script for ScamShield AI.



Validates the scam detection training dataset against:

- DATA_SPEC.md schema requirements

- Task 4.1 acceptance criteria:

  - 1000+ total samples

  - 60% scam, 40% legitimate

  - 50% English, 40% Hindi, 10% Hinglish

  - All samples validated



Output: Validation report with pass/fail status.

"""

import json
import os
import sys
from collections import Counter
from typing import Dict, List, Tuple

# Ensure we can import from app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Required schema fields
REQUIRED_FIELDS = ["id", "message", "language", "label", "confidence", "scam_type", "indicators", "metadata"]
REQUIRED_METADATA_FIELDS = ["source", "annotator", "annotation_date", "difficulty"]

# Valid values
VALID_LANGUAGES = {"en", "hi", "hinglish"}
VALID_LABELS = {"scam", "legitimate"}
VALID_SOURCES = {"synthetic", "real", "curated"}
VALID_ANNOTATORS = {"human", "ai"}
VALID_DIFFICULTIES = {"easy", "medium", "hard"}
VALID_SCAM_TYPES = {
    "lottery", "bank_fraud", "police_threat", "phishing",
    "courier_fraud", "government_impersonation", "upi_fraud", None
}


def load_dataset(filepath: str) -> List[Dict]:
    """Load dataset from JSONL file."""
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset not found: {filepath}")
    
    samples = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                sample = json.loads(line)
                samples.append(sample)
            except json.JSONDecodeError as e:
                print(f"  [ERROR] Line {line_num}: Invalid JSON - {e}")
    
    return samples


def validate_sample_schema(sample: Dict, sample_idx: int) -> List[str]:
    """Validate a single sample against the schema."""
    errors = []
    sample_id = sample.get("id", f"sample_{sample_idx}")
    
    # Check required fields
    for field in REQUIRED_FIELDS:
        if field not in sample:
            errors.append(f"{sample_id}: Missing required field '{field}'")
    
    # Check metadata fields
    if "metadata" in sample and isinstance(sample["metadata"], dict):
        for field in REQUIRED_METADATA_FIELDS:
            if field not in sample["metadata"]:
                errors.append(f"{sample_id}: Missing metadata field '{field}'")
    elif "metadata" in sample:
        errors.append(f"{sample_id}: 'metadata' must be an object")
    
    # Validate field types and values
    if "message" in sample:
        if not isinstance(sample["message"], str):
            errors.append(f"{sample_id}: 'message' must be a string")
        elif len(sample["message"]) == 0:
            errors.append(f"{sample_id}: 'message' cannot be empty")
        elif len(sample["message"]) > 5000:
            errors.append(f"{sample_id}: 'message' exceeds 5000 chars")
    
    if "language" in sample:
        if sample["language"] not in VALID_LANGUAGES:
            errors.append(f"{sample_id}: Invalid language '{sample['language']}'")
    
    if "label" in sample:
        if sample["label"] not in VALID_LABELS:
            errors.append(f"{sample_id}: Invalid label '{sample['label']}'")
    
    if "confidence" in sample:
        if not isinstance(sample["confidence"], (int, float)):
            errors.append(f"{sample_id}: 'confidence' must be a number")
        elif not (0.0 <= sample["confidence"] <= 1.0):
            errors.append(f"{sample_id}: 'confidence' must be between 0 and 1")
    
    if "scam_type" in sample:
        if sample["scam_type"] is not None and sample["scam_type"] not in VALID_SCAM_TYPES:
            errors.append(f"{sample_id}: Invalid scam_type '{sample['scam_type']}'")
        if sample.get("label") == "scam" and sample["scam_type"] is None:
            pass  # Allow null scam_type for scams (some edge cases)
        if sample.get("label") == "legitimate" and sample["scam_type"] is not None:
            errors.append(f"{sample_id}: Legitimate message should have null scam_type")
    
    if "indicators" in sample:
        if not isinstance(sample["indicators"], list):
            errors.append(f"{sample_id}: 'indicators' must be an array")
    
    # Validate metadata values
    if "metadata" in sample and isinstance(sample["metadata"], dict):
        meta = sample["metadata"]
        if meta.get("source") not in VALID_SOURCES:
            errors.append(f"{sample_id}: Invalid source '{meta.get('source')}'")
        if meta.get("annotator") not in VALID_ANNOTATORS:
            errors.append(f"{sample_id}: Invalid annotator '{meta.get('annotator')}'")
        if meta.get("difficulty") not in VALID_DIFFICULTIES:
            errors.append(f"{sample_id}: Invalid difficulty '{meta.get('difficulty')}'")
    
    return errors


def validate_dataset(samples: List[Dict]) -> Dict:
    """Validate complete dataset."""
    results = {
        "total_samples": len(samples),
        "schema_errors": [],
        "unique_ids": set(),
        "duplicate_ids": [],
        "label_counts": Counter(),
        "language_counts": Counter(),
        "scam_type_counts": Counter(),
        "difficulty_counts": Counter(),
    }
    
    # Validate each sample
    for idx, sample in enumerate(samples):
        # Schema validation
        errors = validate_sample_schema(sample, idx)
        results["schema_errors"].extend(errors)
        
        # Check for duplicate IDs
        sample_id = sample.get("id")
        if sample_id:
            if sample_id in results["unique_ids"]:
                results["duplicate_ids"].append(sample_id)
            else:
                results["unique_ids"].add(sample_id)
        
        # Count distributions
        results["label_counts"][sample.get("label")] += 1
        results["language_counts"][sample.get("language")] += 1
        results["scam_type_counts"][sample.get("scam_type")] += 1
        if "metadata" in sample and isinstance(sample["metadata"], dict):
            results["difficulty_counts"][sample["metadata"].get("difficulty")] += 1
    
    # Calculate ratios
    total = results["total_samples"]
    if total > 0:
        results["scam_ratio"] = results["label_counts"].get("scam", 0) / total
        results["legit_ratio"] = results["label_counts"].get("legitimate", 0) / total
        results["en_ratio"] = results["language_counts"].get("en", 0) / total
        results["hi_ratio"] = results["language_counts"].get("hi", 0) / total
        results["hinglish_ratio"] = results["language_counts"].get("hinglish", 0) / total
    else:
        results["scam_ratio"] = 0
        results["legit_ratio"] = 0
        results["en_ratio"] = 0
        results["hi_ratio"] = 0
        results["hinglish_ratio"] = 0
    
    # Check acceptance criteria
    results["criteria"] = {
        "total_1000": {
            "pass": total >= 1000,
            "value": total,
            "required": ">=1000",
        },
        "scam_60": {
            "pass": 0.55 <= results["scam_ratio"] <= 0.65,
            "value": f"{results['scam_ratio']:.1%}",
            "required": "55-65%",
        },
        "en_50": {
            "pass": 0.45 <= results["en_ratio"] <= 0.55,
            "value": f"{results['en_ratio']:.1%}",
            "required": "45-55%",
        },
        "hi_40": {
            "pass": 0.35 <= results["hi_ratio"] <= 0.45,
            "value": f"{results['hi_ratio']:.1%}",
            "required": "35-45%",
        },
        "hinglish_10": {
            "pass": 0.05 <= results["hinglish_ratio"] <= 0.15,
            "value": f"{results['hinglish_ratio']:.1%}",
            "required": "5-15%",
        },
        "no_schema_errors": {
            "pass": len(results["schema_errors"]) == 0,
            "value": len(results["schema_errors"]),
            "required": "0 errors",
        },
        "no_duplicate_ids": {
            "pass": len(results["duplicate_ids"]) == 0,
            "value": len(results["duplicate_ids"]),
            "required": "0 duplicates",
        },
    }
    
    results["all_pass"] = all(c["pass"] for c in results["criteria"].values())
    
    return results


def print_report(results: Dict) -> None:
    """Print validation report."""
    print("\n" + "=" * 60)
    print("Dataset Validation Report")
    print("=" * 60)
    
    print(f"\n--- Summary ---")
    print(f"Total samples: {results['total_samples']}")
    print(f"Unique IDs: {len(results['unique_ids'])}")
    print(f"Schema errors: {len(results['schema_errors'])}")
    print(f"Duplicate IDs: {len(results['duplicate_ids'])}")
    
    print(f"\n--- Label Distribution ---")
    for label, count in sorted(results["label_counts"].items()):
        ratio = count / results["total_samples"] if results["total_samples"] > 0 else 0
        print(f"  {label}: {count} ({ratio:.1%})")
    
    print(f"\n--- Language Distribution ---")
    for lang, count in sorted(results["language_counts"].items()):
        ratio = count / results["total_samples"] if results["total_samples"] > 0 else 0
        print(f"  {lang}: {count} ({ratio:.1%})")
    
    print(f"\n--- Scam Type Distribution ---")
    for scam_type, count in sorted(results["scam_type_counts"].items(), key=lambda x: (x[0] is None, x[0])):
        print(f"  {scam_type or 'None'}: {count}")
    
    print(f"\n--- Difficulty Distribution ---")
    for diff, count in sorted(results["difficulty_counts"].items()):
        print(f"  {diff}: {count}")
    
    print(f"\n--- Acceptance Criteria ---")
    for name, criterion in results["criteria"].items():
        status = "PASS" if criterion["pass"] else "FAIL"
        print(f"  {name}: {criterion['value']} (required: {criterion['required']}) - {status}")
    
    if results["schema_errors"]:
        print(f"\n--- Schema Errors (first 10) ---")
        for error in results["schema_errors"][:10]:
            print(f"  - {error}")
        if len(results["schema_errors"]) > 10:
            print(f"  ... and {len(results['schema_errors']) - 10} more")
    
    if results["duplicate_ids"]:
        print(f"\n--- Duplicate IDs ---")
        for dup_id in results["duplicate_ids"][:10]:
            print(f"  - {dup_id}")
    
    print(f"\n--- Final Result ---")
    if results["all_pass"]:
        print("ALL ACCEPTANCE CRITERIA PASSED")
    else:
        print("SOME ACCEPTANCE CRITERIA FAILED")


def main():
    """Main entry point."""
    print("=" * 60)
    print("ScamShield AI - Dataset Validation")
    print("=" * 60)
    
    # Determine dataset path
    dataset_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "data",
        "scam_detection_train.jsonl"
    )
    
    if len(sys.argv) > 1:
        dataset_path = sys.argv[1]
    
    print(f"\nValidating: {dataset_path}")
    
    # Load dataset
    try:
        samples = load_dataset(dataset_path)
        print(f"Loaded {len(samples)} samples")
    except FileNotFoundError as e:
        print(f"[ERROR] {e}")
        return 1
    except Exception as e:
        print(f"[ERROR] Failed to load dataset: {e}")
        return 1
    
    # Validate
    results = validate_dataset(samples)
    
    # Print report
    print_report(results)
    
    return 0 if results["all_pass"] else 1


if __name__ == "__main__":
    sys.exit(main())