Spaces:
Sleeping
Sleeping
| # src/validation.py | |
| import pandas as pd | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Optional | |
| import json | |
| import io | |
| import re | |
| from config import ( | |
| PREDICTION_FORMAT, | |
| VALIDATION_CONFIG, | |
| MODEL_CATEGORIES, | |
| EVALUATION_TRACKS, | |
| ALL_UG40_LANGUAGES, | |
| ) | |
def detect_model_category(model_name: str, author: str, description: str) -> str:
    """Automatically detect model category based on name and metadata."""
    # All metadata is folded into one lowercase haystack for substring tests.
    haystack = f"{model_name} {author} {description}".lower()
    detection_patterns = PREDICTION_FORMAT["category_detection"]

    def _mentions(terms):
        # True when any of the given terms occurs in the combined metadata.
        return any(term in haystack for term in terms)

    # Known-system patterns from config, checked in priority order.
    for pattern_key, category in (
        ("google", "commercial"),
        ("nllb", "research"),
        ("m2m", "research"),
        ("baseline", "baseline"),
    ):
        if _mentions(detection_patterns.get(pattern_key, [])):
            return category

    # Heuristic keyword lists: research provenance first, then commercial.
    if _mentions([
        "university", "research", "paper", "arxiv", "acl", "emnlp", "naacl",
        "transformer", "bert", "gpt", "t5", "mbart", "academic",
    ]):
        return "research"
    if _mentions([
        "google", "microsoft", "azure", "aws", "openai", "anthropic",
        "commercial", "api", "cloud", "translate",
    ]):
        return "commercial"

    # Anything unrecognized is treated as a community submission.
    return "community"
def validate_file_format(file_content: bytes, filename: str) -> Dict:
    """Validate file format and structure.

    Parses ``file_content`` according to the extension of ``filename``
    (.csv, .tsv, or .json), then checks required columns, non-emptiness,
    missing values, duplicate sample IDs, and sample_id format.

    Returns:
        dict with ``valid`` (bool) and ``error`` (str) on failure; on a
        successful parse also ``dataframe``, ``row_count`` and ``columns``
        (returned even when validation issues were found, so callers can
        inspect the data).
    """
    try:
        # Parse by extension; unsupported types are rejected up front.
        if filename.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(file_content))
        elif filename.endswith(".tsv"):
            df = pd.read_csv(io.BytesIO(file_content), sep="\t")
        elif filename.endswith(".json"):
            data = json.loads(file_content.decode("utf-8"))
            df = pd.DataFrame(data)
        else:
            return {
                "valid": False,
                "error": f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}",
            }

        # Required columns must all exist before any column access below.
        missing_cols = set(PREDICTION_FORMAT["required_columns"]) - set(df.columns)
        if missing_cols:
            return {
                "valid": False,
                "error": f"Missing required columns: {', '.join(missing_cols)}",
            }

        # Basic data validation
        if len(df) == 0:
            return {"valid": False, "error": "File is empty"}

        validation_issues = []

        # NOTE(review): assumes "sample_id" and "prediction" are among the
        # columns configured in PREDICTION_FORMAT["required_columns"].
        if df["sample_id"].isna().any():
            validation_issues.append("Missing sample_id values found")

        if df["prediction"].isna().any():
            na_count = df["prediction"].isna().sum()
            validation_issues.append(f"Missing prediction values found ({na_count} empty predictions)")

        # Check for duplicates
        duplicates = df["sample_id"].duplicated()
        if duplicates.any():
            dup_count = duplicates.sum()
            validation_issues.append(f"Duplicate sample_id values found ({dup_count} duplicates)")

        # Normalize IDs to strings so the regex check below is well-defined.
        if not df["sample_id"].dtype == "object":
            df["sample_id"] = df["sample_id"].astype(str)

        # Use fullmatch (not match) so trailing garbage is rejected:
        # str.match only anchors at the start, so "salt_1234567" or
        # "salt_123456x" would previously pass as valid IDs.
        invalid_ids = ~df["sample_id"].str.fullmatch(r"salt_\d{6}", na=False)
        if invalid_ids.any():
            invalid_count = invalid_ids.sum()
            validation_issues.append(f"Invalid sample_id format found ({invalid_count} invalid IDs)")

        # Return results (dataframe is included either way for inspection).
        if validation_issues:
            return {
                "valid": False,
                "error": "; ".join(validation_issues),
                "dataframe": df,
                "row_count": len(df),
                "columns": list(df.columns),
            }

        return {
            "valid": True,
            "dataframe": df,
            "row_count": len(df),
            "columns": list(df.columns),
        }
    except Exception as e:
        # Any parse/validation crash is reported as an invalid submission
        # rather than propagated to the caller.
        return {"valid": False, "error": f"Error parsing file: {str(e)}"}
def validate_predictions_content(predictions: pd.DataFrame) -> Dict:
    """Validate prediction content quality.

    Checks for empty, very short/long, duplicated, and placeholder-like
    prediction strings, and derives a heuristic quality score in [0, 1].

    Returns:
        dict with keys ``has_issues``, ``issues``, ``warnings``,
        ``quality_score`` and ``quality_metrics``.
    """
    issues = []
    warnings = []
    quality_metrics = {}

    # Basic content checks: whitespace-only strings count as empty.
    empty_predictions = predictions["prediction"].str.strip().eq("").sum()
    if empty_predictions > 0:
        issues.append(f"{empty_predictions} empty predictions found")

    # Length analysis (std is NaN for a single-row frame; tolerated downstream).
    pred_lengths = predictions["prediction"].str.len()
    quality_metrics["avg_length"] = float(pred_lengths.mean())
    quality_metrics["std_length"] = float(pred_lengths.std())

    # Check for suspiciously short predictions
    short_predictions = (pred_lengths < 3).sum()
    if short_predictions > len(predictions) * 0.05:  # More than 5%
        issues.append(f"{short_predictions} very short predictions (< 3 characters)")

    # Check for suspiciously long predictions
    long_predictions = (pred_lengths > 500).sum()
    if long_predictions > len(predictions) * 0.01:  # More than 1%
        warnings.append(f"{long_predictions} very long predictions (> 500 characters)")

    # Check for repeated predictions
    duplicate_predictions = predictions["prediction"].duplicated().sum()
    duplicate_rate = duplicate_predictions / len(predictions)
    quality_metrics["duplicate_rate"] = float(duplicate_rate)
    if duplicate_rate > VALIDATION_CONFIG["quality_thresholds"]["max_duplicate_rate"]:
        issues.append(f"{duplicate_predictions} duplicate prediction texts ({duplicate_rate:.1%})")

    # Check for placeholder text
    placeholder_patterns = [
        r"^(test|placeholder|todo|xxx|aaa|bbb)$",
        r"^[a-z]{1,3}$",  # Very short gibberish
        r"^\d+$",  # Just numbers
        r"^[^\w\s]*$",  # Only punctuation
    ]
    # OR the per-pattern masks so each row is counted at most once.
    # The previous per-pattern sum double-counted predictions matching
    # several patterns (e.g. "aaa" matches both of the first two).
    placeholder_mask = pd.Series(False, index=predictions.index)
    for pattern in placeholder_patterns:
        placeholder_mask |= predictions["prediction"].str.match(pattern, flags=re.IGNORECASE, na=False)
    placeholder_count = int(placeholder_mask.sum())
    if placeholder_count > len(predictions) * 0.02:  # More than 2%
        issues.append(f"{placeholder_count} placeholder-like predictions detected")

    # Calculate overall quality score
    quality_score = 1.0
    quality_score -= len(issues) * 0.3  # Major penalty for issues
    quality_score -= len(warnings) * 0.1  # Minor penalty for warnings
    quality_score -= max(0, duplicate_rate - 0.05) * 2  # Penalty for excessive duplicates

    # Length appropriateness relative to configured thresholds.
    if quality_metrics["avg_length"] < VALIDATION_CONFIG["quality_thresholds"]["min_avg_length"]:
        quality_score -= 0.2
    elif quality_metrics["avg_length"] > VALIDATION_CONFIG["quality_thresholds"]["max_avg_length"]:
        quality_score -= 0.1

    # Clamp to the documented [0, 1] range.
    quality_score = max(0.0, min(1.0, quality_score))

    return {
        "has_issues": len(issues) > 0,
        "issues": issues,
        "warnings": warnings,
        "quality_score": quality_score,
        "quality_metrics": quality_metrics,
    }
def validate_against_test_set(
    predictions: pd.DataFrame, test_set: pd.DataFrame
) -> Dict:
    """Validate predictions against test set.

    Computes overall sample-ID coverage, per-track coverage for every
    track in EVALUATION_TRACKS, and checks the missing-ID rate against
    VALIDATION_CONFIG["max_missing_rate"].

    Returns:
        dict with coverage counts/rates, threshold flags, per-track
        ``track_coverage``, and up-to-10-item samples of missing/extra IDs.
    """
    # Convert IDs to string so int/str mixes compare consistently.
    pred_ids = set(predictions["sample_id"].astype(str))
    test_ids = set(test_set["sample_id"].astype(str))

    # Overall coverage bookkeeping.
    missing_ids = test_ids - pred_ids
    extra_ids = pred_ids - test_ids
    matching_ids = pred_ids & test_ids

    # Guard against an empty test set so the ratios below cannot raise
    # ZeroDivisionError; an empty reference means zero coverage.
    total_test = len(test_ids)
    overall_coverage = len(matching_ids) / total_test if total_test else 0.0

    # Track-specific coverage analysis
    track_coverage = {}
    for track_name, track_config in EVALUATION_TRACKS.items():
        track_languages = track_config["languages"]

        # Restrict to pairs where both source and target are track languages.
        track_test_set = test_set[
            (test_set["source_language"].isin(track_languages)) &
            (test_set["target_language"].isin(track_languages))
        ]
        if len(track_test_set) == 0:
            continue

        track_test_ids = set(track_test_set["sample_id"].astype(str))
        track_matching_ids = pred_ids & track_test_ids

        # NOTE(review): assumes every track name has an entry in
        # VALIDATION_CONFIG["min_samples_per_track"] — verify in config.
        min_required = VALIDATION_CONFIG["min_samples_per_track"][track_name]
        track_coverage[track_name] = {
            "total_samples": len(track_test_set),
            "covered_samples": len(track_matching_ids),
            "coverage_rate": len(track_matching_ids) / len(track_test_set),
            "meets_minimum": len(track_matching_ids) >= min_required,
            "min_required": min_required,
        }

    # Missing rate validation (same empty-test-set guard as above).
    missing_rate = len(missing_ids) / total_test if total_test else 0.0
    meets_missing_threshold = missing_rate <= VALIDATION_CONFIG["max_missing_rate"]

    return {
        "overall_coverage": overall_coverage,
        "missing_count": len(missing_ids),
        "extra_count": len(extra_ids),
        "matching_count": len(matching_ids),
        "missing_rate": missing_rate,
        "meets_missing_threshold": meets_missing_threshold,
        "is_complete": overall_coverage == 1.0,
        "track_coverage": track_coverage,
        "missing_ids_sample": list(missing_ids)[:10],
        "extra_ids_sample": list(extra_ids)[:10],
    }
def generate_validation_report(
    format_result: Dict,
    content_result: Dict,
    test_set_result: Dict,
    model_name: str = "",
    detected_category: str = "community",
) -> str:
    """Generate comprehensive validation report.

    Builds a markdown report from the three validation-result dicts.
    If the file format is invalid, the report stops after the format
    section (content_result / test_set_result may then be empty dicts,
    as done by validate_submission).

    NOTE(review): the status markers ("β ", "β", "β οΈ", "π¬", ...) look
    like mojibake of the original emoji (likely ✅/❌/⚠️/🔬) — confirm
    against the rendered UI before changing any of these strings.
    """
    report = []

    # Header
    report.append(f"### π¬ Validation Report: {model_name or 'Submission'}")
    report.append("")

    # Model categorization — unknown categories fall back to "community".
    category_info = MODEL_CATEGORIES.get(detected_category, MODEL_CATEGORIES["community"])
    report.append(f"**Detected Model Category**: {category_info['name']}")
    report.append("")

    # File format validation
    if format_result["valid"]:
        report.append("β **File Format**: Valid")
        report.append(f" - Rows: {format_result['row_count']:,}")
        report.append(f" - Columns: {', '.join(format_result['columns'])}")
    else:
        report.append("β **File Format**: Invalid")
        report.append(f" - Error: {format_result['error']}")
        # Early return: the remaining sections require parsed predictions.
        return "\n".join(report)

    # Content quality validation
    quality_score = content_result.get("quality_score", 0.0)
    if content_result["has_issues"]:
        report.append("β **Content Quality**: Issues Found")
        for issue in content_result["issues"]:
            report.append(f" - β {issue}")
    else:
        report.append("β **Content Quality**: Good")
    # Warnings are non-blocking and listed regardless of issue status.
    if content_result["warnings"]:
        for warning in content_result["warnings"]:
            report.append(f" - β οΈ {warning}")
    report.append(f" - **Quality Score**: {quality_score:.2f}/1.00")
    report.append("")

    # Test set coverage validation
    overall_coverage = test_set_result["overall_coverage"]
    meets_threshold = test_set_result["meets_missing_threshold"]
    if overall_coverage == 1.0:
        report.append("β **Test Set Coverage**: Complete")
    elif overall_coverage >= 0.95 and meets_threshold:
        report.append("β **Test Set Coverage**: Adequate")
    else:
        report.append("β **Test Set Coverage**: Insufficient")
    report.append(f" - Coverage: {overall_coverage:.1%} ({test_set_result['matching_count']:,} / {test_set_result['matching_count'] + test_set_result['missing_count']:,})")
    report.append(f" - Missing Rate: {test_set_result['missing_rate']:.1%}")
    report.append("")

    # Track-specific coverage analysis
    report.append("#### π Track-Specific Analysis")
    track_coverage = test_set_result.get("track_coverage", {})
    for track_name, coverage_info in track_coverage.items():
        track_config = EVALUATION_TRACKS[track_name]
        status = "β " if coverage_info["meets_minimum"] else "β"
        report.append(f"**{status} {track_config['name']}**:")
        report.append(f" - **Samples**: {coverage_info['covered_samples']:,} / {coverage_info['total_samples']:,}")
        report.append(f" - **Coverage**: {coverage_info['coverage_rate']:.1%}")
        report.append(f" - **Minimum Required**: {coverage_info['min_required']:,}")
        report.append(f" - **Status**: {'Adequate' if coverage_info['meets_minimum'] else 'Insufficient'}")
        report.append("")

    # Final verdict: strict pass (same thresholds as validate_submission's
    # `is_valid`) vs. a permissive "evaluate with limitations" tier.
    all_checks_pass = (
        format_result["valid"] and
        not content_result["has_issues"] and
        overall_coverage >= 0.95 and
        meets_threshold
    )
    # NOTE(review): content issues never contain the "β" marker (it is only
    # prepended when rendering the report above), so this `any(...)` appears
    # to be always False — confirm which marker was originally intended.
    can_evaluate_with_limits = (
        format_result["valid"] and
        overall_coverage >= 0.8 and
        not any("β" in issue for issue in content_result.get("issues", []))
    )
    if all_checks_pass:
        report.append("π **Final Verdict**: Ready for evaluation!")
    elif can_evaluate_with_limits:
        report.append("β οΈ **Final Verdict**: Can be evaluated with limitations")
        report.append(" - Results will include notes about limitations")
    else:
        report.append("β **Final Verdict**: Please address critical issues before submission")

    return "\n".join(report)
def validate_submission(
    file_content: bytes,
    filename: str,
    test_set: pd.DataFrame,
    model_name: str = "",
    author: str = "",
    description: str = ""
) -> Dict:
    """Complete validation pipeline for submissions."""
    # Category is detected first so even format failures report it.
    category = detect_model_category(model_name, author, description)

    # File format validation — a parse failure short-circuits the pipeline.
    fmt = validate_file_format(file_content, filename)
    if not fmt["valid"]:
        early_report = generate_validation_report(
            fmt, {}, {}, model_name, category
        )
        return {
            "valid": False,
            "can_evaluate": False,
            "category": category,
            "report": early_report,
            "predictions": None,
        }

    predictions = fmt["dataframe"]

    # Content quality and test-set coverage checks on the parsed frame.
    content = validate_predictions_content(predictions)
    coverage = validate_against_test_set(predictions, test_set)

    report = generate_validation_report(
        fmt, content, coverage, model_name, category
    )

    # Strict validity: clean content plus near-complete coverage.
    strict_ok = (
        fmt["valid"]
        and not content["has_issues"]
        and coverage["overall_coverage"] >= 0.95
        and coverage["meets_missing_threshold"]
    )

    # Permissive eligibility: lower coverage bar, no critical-marker issues.
    lenient_ok = (
        fmt["valid"]
        and coverage["overall_coverage"] >= 0.8
        and not any("β" in issue for issue in content.get("issues", []))
    )

    return {
        "valid": strict_ok,
        "can_evaluate": lenient_ok,
        "category": category,
        "coverage": coverage["overall_coverage"],
        "report": report,
        "predictions": predictions,
        "quality_score": content.get("quality_score", 0.8),
        "track_coverage": coverage.get("track_coverage", {}),
    }