Spaces:
Runtime error
Runtime error
| import json | |
| import traceback | |
| from fastapi import HTTPException, UploadFile | |
| from typing import Optional, List, Dict, Any, Tuple, Union | |
| import uuid | |
| from datetime import datetime, timezone | |
| import pandas as pd | |
| from app.schemas.prediction import BatchPredictionResponse | |
| from app.schemas.validation import ValidationResponse | |
| from app.services.auth import UserService | |
| from app.core.enums import * | |
| from app.core.constant import * | |
| from app.utils.s3_utils import * | |
| from app.services.models import async_model_crud | |
| from app.utils.dataset_utils import convert_firco_json_to_dataframe, convert_non_firco_json_to_dataframe | |
| import io | |
| from app.core.enums import ModelName | |
| from app.ml.models.firco_ensemble_model import FircoEnsembleXGBModel | |
| from app.ml.models.firco_tfidf_model import FircoTFIDFXGBModel | |
| from app.ml.models.non_firco_roberta_model import NonFircoRobertaModel | |
| from app.services.user_and_model import async_user_and_model_crud | |
| user_service = UserService() | |
def validate_level(level: str):
    """Ensure ``level`` is one of the accepted prediction levels.

    Args:
        level: Requested level; must be 'hit', 'message' or 'both'.

    Raises:
        HTTPException: 400 when ``level`` is anything else.
    """
    allowed_levels = ("hit", "message", "both")
    if level in allowed_levels:
        return
    raise HTTPException(status_code=400, detail="level must be 'hit', 'message', or 'both'")
def format_prediction_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Wrap the ``model_dump()`` of every run under the ``prediction_runs`` key."""
    # NOTE(review): a function with this same name is defined again further
    # down in this module; at import time the later definition replaces this one.
    dumped_runs = []
    for run in runs:
        dumped_runs.append(run.model_dump())
    return {"prediction_runs": dumped_runs}
def format_single_run_response(run: Any, entity_name: str) -> Dict[str, Any]:
    """Serialize one run (training/validation/prediction) for an API response.

    Args:
        run: Object exposing ``model_dump()``; a falsy value means "not found".
        entity_name: Label used in the 404 message.

    Returns:
        The result of ``run.model_dump()``.

    Raises:
        HTTPException: 404 when ``run`` is falsy.
    """
    if run:
        return run.model_dump()
    raise HTTPException(status_code=404, detail=f"{entity_name} not found")
## To be changed later (should use version instead of training)
async def get_version_from_model_id(model_id: str, async_model_crud) -> str:
    """Return the training_id of the first training run listed for a model.

    Despite the function name, the value returned is a ``training_id``.

    Args:
        model_id: Identifier of the model to look up.
        async_model_crud: CRUD object exposing ``get_by_model_id``. This
            parameter shadows the module-level import of the same name.

    Returns:
        ``training_id`` of the first run returned by
        ``TrainingService.list_by_model``.

    Raises:
        HTTPException: 404 when the model does not exist or has no training runs.
    """
    # Function-scope import, kept as in the original code
    # (presumably to avoid a circular import — TODO confirm).
    from app.services.train import TrainingService

    training_service = TrainingService()

    model_record = await async_model_crud.get_by_model_id(model_id)
    if not model_record:
        raise HTTPException(status_code=404, detail=f"Model with model_id '{model_id}' not found")

    training_runs = await training_service.list_by_model(model_id)
    if not training_runs:
        raise HTTPException(status_code=404, detail=f"No training runs found for model_id '{model_id}'")

    # Treated as the latest training run; this relies on list_by_model
    # returning newest-first — TODO confirm against TrainingService.
    return training_runs[0].training_id
def get_model_files_for_level(level: str, training_id: str) -> List[tuple]:
    """Build the (s3_filename, zip_filename) pairs for one level's artifacts.

    Three artifacts are listed, in this order: the XGB model, the feature
    encoder and the label encoder.
    """
    # (prefix of the S3 file name, suffix of the name used inside the zip)
    artifact_names = (
        ("xgb_model", "model"),
        ("feature_encoder", "feature_encoder"),
        ("label_encoder", "label_encoder"),
    )
    return [
        (f"{s3_prefix}_{level}_{training_id}.pkl", f"{level}_{zip_suffix}.pkl")
        for s3_prefix, zip_suffix in artifact_names
    ]
def get_download_files_and_name(level: str, training_id: str, model_id: str) -> tuple:
    """Return the artifact files to download and the name of the zip to build.

    For level 'both' the hit files come first, followed by the message files.

    Returns:
        ``(files_to_download, zip_filename)`` where ``files_to_download`` is a
        list of ``(s3_filename, zip_filename)`` pairs.
    """
    levels = ("hit", "message") if level == "both" else (level,)
    files_to_download = []
    for single_level in levels:
        files_to_download.extend(get_model_files_for_level(single_level, training_id))
    zip_filename = f"model_{model_id}_{level}_{training_id}.zip"
    return files_to_download, zip_filename
| # Additional helper functions for model_utils.py duplication reduction | |
def validate_model_name_from_string(model_name_str: str):
    """Convert a model-name string to its ``ModelName`` enum member.

    Raises:
        HTTPException: 400 listing the supported values when the string does
            not match any ``ModelName`` member.
    """
    try:
        model_name = ModelName(model_name_str)
    except ValueError:
        supported_values = [e.value for e in ModelName]
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model_name: {model_name_str}. Supported types: {supported_values}"
        )
    return model_name
def validate_dataframe_not_empty(df, context: str = "data"):
    """Raise a 400 error when ``df`` holds no data.

    Args:
        df: Object exposing ``.empty`` and ``len()`` (a DataFrame in this module).
        context: Human-readable name of the data, used in the error message.

    Raises:
        HTTPException: 400 when the dataframe is empty.
    """
    has_rows = not df.empty and len(df) != 0
    if has_rows:
        return
    raise HTTPException(
        status_code=400,
        detail=f"The {context} is empty or contains no valid data after processing"
    )
def validate_prediction_csv_params(file: UploadFile, level: str, training_id: str):
    """Run every check required by the CSV prediction endpoint.

    The checks run in a fixed order and the first failure is the one reported:
    training_id present, level valid, file present, file is a CSV.
    """
    validate_required_param(param_value=training_id, param_name="training_id")
    validate_level(level=level)
    validate_required_param(param_value=file, param_name="file")
    validate_csv_file(file=file)
def validate_prediction_json_params(training_id: str, level: str):
    """Run every check required by the JSON prediction endpoint.

    ``training_id`` must be provided, then ``level`` must be an accepted value.
    """
    validate_required_param(param_value=training_id, param_name="training_id")
    validate_level(level=level)
def generate_operation_id() -> str:
    """Build a unique operation ID (training, validation, prediction).

    Format: ``YYYYMMDD_HHMMSS_<first 12 characters of a uuid4 string>``. The
    timestamp comes from the naive local ``datetime.now()``; the 12-character
    slice includes the first hyphen of the UUID.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_suffix = str(uuid.uuid4())[:12]
    return f"{timestamp}_{unique_suffix}"
async def create_database_record(crud_instance, record_data, record_type: str):
    """Create a database record through ``crud_instance`` with error handling.

    Args:
        crud_instance: Object exposing an awaitable ``create(record_data)``.
        record_data: Payload handed to ``create`` unchanged.
        record_type: Label used in the log lines.

    Returns:
        Whatever ``crud_instance.create`` returned.

    Raises:
        HTTPException: 500 when ``create`` raises; the original exception is
            chained as ``__cause__``.
    """
    # Only the create call is guarded: a failure while logging a record that
    # was already created must not be reported as a database error.
    try:
        record = await crud_instance.create(record_data)
    except Exception as e:
        print(f"[DB] Error creating {record_type} record: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e

    # getattr keeps the log line from failing on records without an ``id``.
    print(f"[DB] {record_type} record created: {getattr(record, 'id', None)}")
    return record
# TO DO apply logic to get ensemble
def get_model_instances(training_id=None, model_name = None, targets = None) -> Any:
    """Instantiate the model class that corresponds to ``model_name``.

    Args:
        training_id: Forwarded to the model constructor.
        model_name: ``ModelName`` member selecting the implementation.
        targets: Forwarded to the model constructor.

    Returns:
        A single instance of the selected model class, built with
        ``training_id`` and ``targets``.

    Raises:
        ValueError: When ``model_name`` has no model class mapped to it.
            Previously this surfaced as an UnboundLocalError.
    """
    # Equality chain (not a dict lookup) so the comparison semantics of
    # ModelName members are exactly the ones the original code relied on.
    if model_name == ModelName.FIRCO_ENSEMBLE_XGB:
        ModelClass = FircoEnsembleXGBModel
    elif model_name == ModelName.FIRCO_TFIDF_XGB:
        ModelClass = FircoTFIDFXGBModel
    elif model_name == ModelName.ROBERTA_COMPLIANCE:
        ModelClass = NonFircoRobertaModel
    else:
        raise ValueError(f"Unsupported model_name: {model_name}")
    return ModelClass(training_id=training_id, targets=targets)
def combine_both_level_results(hit_results: Dict, message_results: Dict, operation: str) -> Dict:
    """Merge hit-level and message-level results into one payload.

    Args:
        hit_results: Result dict from the hit-level run.
        message_results: Result dict from the message-level run.
        operation: "validation" or "training".

    Returns:
        For "validation": the mean accuracy, both classification reports, the
        hit-level sample count and both raw result dicts.
        For "training": the hit-level data size, both sets of S3 keys and the
        summed training time.
        For any other operation the function returns ``None``.
    """
    if operation == "validation":
        hit_accuracy = hit_results["accuracy"]
        message_accuracy = message_results["accuracy"]
        combined = {"accuracy": (hit_accuracy + message_accuracy) / 2}
        combined["metrics"] = {
            "hit_level": hit_results["classification_report"],
            "message_level": message_results["classification_report"],
        }
        combined["data_size"] = {"total_records": hit_results["total_samples"]}
        combined["hit_results"] = hit_results
        combined["message_results"] = message_results
        return combined

    if operation == "training":
        hit_keys = hit_results['artifacts']['s3_keys']
        message_keys = message_results['artifacts']['s3_keys']
        merged_keys = {**hit_keys, **message_keys}
        print(f"[DEBUG] Combining training results for both levels resultant artifacts is {merged_keys}")
        return {
            "data_size": hit_results["data_size"],
            "artifacts": {"s3_keys": {"hit": hit_keys, "message": message_keys}},
            "training_time": hit_results["training_time"] + message_results["training_time"],
        }
def validate_model_name_support(model_name, operation: str = "operation"):
    """Check that ``model_name`` is a member of ``ModelName``.

    Args:
        model_name: Value to check against the ``ModelName`` members.
        operation: Name of the calling operation, used in the error message.

    Raises:
        ValueError: When ``model_name`` is not among the ``ModelName`` members.
    """
    # Every ModelName member currently counts as supported by this API.
    supported_names = list(ModelName.__members__.values())
    if model_name in supported_names:
        return
    raise ValueError(
        f"Model name {model_name} not yet implemented for {operation}. "
        f"Currently supported: {[t.value for t in supported_names]}"
    )
async def process_json_to_dataframe(data, compliance_type: ComplianceType, debug_context: str = "PREDICT") -> Tuple[Any, str, str]:
    """Convert a JSON payload to a DataFrame and persist a CSV copy of it.

    The DataFrame is serialized to CSV, uploaded to S3 under a timestamped
    ``debug_json_data_*.csv`` name and also written to a temp file.

    Args:
        data: JSON payload handed to the FIRCO / non-FIRCO converter.
        compliance_type: Selects the converter (FIRCO vs. everything else).
        debug_context: Tag placed in the log lines.

    Returns:
        ``(df, file_path, s3_key)``.

    Raises:
        HTTPException: 400 (via ``validate_dataframe_not_empty``) when the
            converted DataFrame is empty.
    """
    is_firco = compliance_type == ComplianceType.FIRCO
    if is_firco:
        df = convert_firco_json_to_dataframe(data)
        print(f"[{debug_context}] Firco JSON data converted to DataFrame. Shape: {df.shape}")
    else:
        print("[DEBUG] Non-Firco JSON data received for conversion to DataFrame")
        df = convert_non_firco_json_to_dataframe(data)
        print(f"[{debug_context}] Non-Firco JSON data converted to DataFrame. Shape: {df.shape}")

    validate_dataframe_not_empty(df, "JSON data")

    # The debug copy goes to S3 rather than local disk.
    debug_filename = "debug_json_data_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".csv"

    # Serialize the DataFrame to UTF-8 CSV bytes in memory.
    text_buffer = io.StringIO()
    df.to_csv(text_buffer, index=False)
    csv_bytes = text_buffer.getvalue().encode('utf-8')

    s3_key = s3_key_for_upload(debug_filename)
    await async_upload_file_to_s3(csv_bytes, s3_key)

    # Temp file handed back to the caller for further processing.
    file_path = await async_write_temp_file(csv_bytes, suffix='.csv')

    print(f"[{debug_context}] Debug DataFrame uploaded to S3: {debug_filename}")
    print(f"[{debug_context}] DataFrame columns: {df.columns.tolist()}")
    return df, file_path, s3_key
def validate_csv_file(file: UploadFile):
    """Validate that the uploaded file has a ``.csv`` extension.

    The check is case-insensitive, and a missing or empty filename is
    rejected with the same 400 error instead of failing with AttributeError.

    Args:
        file: The uploaded file; only its ``filename`` attribute is inspected.

    Raises:
        HTTPException: 400 when the filename is missing or does not end in ``.csv``.
    """
    # The filename may be absent on an upload, so fall back to an empty string.
    filename = getattr(file, "filename", None) or ""
    if not filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
def validate_required_param(param_value: Any, param_name: str):
    """Reject a missing required parameter.

    Any falsy value (``None``, ``""``, ``0``, empty container) counts as missing.

    Raises:
        HTTPException: 400 naming the missing parameter.
    """
    if param_value:
        return
    raise HTTPException(status_code=400, detail=f"{param_name} is required")
def validate_user_id(user_id: str):
    """Require ``user_id`` to be a string with at least one non-blank character.

    Raises:
        HTTPException: 400 when the value is missing, not a string, or blank.
    """
    is_valid = isinstance(user_id, str) and bool(user_id.strip())
    if is_valid:
        return
    raise HTTPException(status_code=400, detail="Valid user_id is required")
def validate_version(version: str):
    """Require ``version`` to be a string with at least one non-blank character.

    Raises:
        HTTPException: 400 when the value is missing, not a string, or blank.
    """
    is_valid = isinstance(version, str) and bool(version.strip())
    if is_valid:
        return
    raise HTTPException(status_code=400, detail="Valid version is required")
def validate_model_id(model_id: str):
    """Require ``model_id`` to be a string with at least one non-blank character.

    Raises:
        HTTPException: 400 when the value is missing, not a string, or blank.
    """
    is_valid = isinstance(model_id, str) and bool(model_id.strip())
    if is_valid:
        return
    raise HTTPException(status_code=400, detail="Valid model_id is required")
def validate_required_columns(df: pd.DataFrame, required_columns: List[str], context: str = "dataset"):
    """Check that ``df`` contains every column in ``required_columns``.

    Args:
        df: DataFrame to inspect.
        required_columns: Column names that must be present.
        context: Name of the data source, used in the error message.

    Raises:
        HTTPException: 400 listing the missing columns, in the order requested.
    """
    available = df.columns
    missing_columns = [name for name in required_columns if name not in available]
    if not missing_columns:
        return
    raise HTTPException(
        status_code=400,
        detail=f"Missing required columns in {context}: {missing_columns}"
    )
def format_model_list_response(models: List[Any]) -> Dict[str, List[Dict]]:
    """Serialize a list of models under the ``models`` key.

    Each item is converted with ``model_dump()`` when available, otherwise
    ``dict()`` the method, otherwise the ``dict(...)`` constructor.
    """
    def _as_dict(item):
        if hasattr(item, 'model_dump'):
            return item.model_dump()
        if hasattr(item, 'dict'):
            return item.dict()
        return dict(item)

    return {"models": [_as_dict(model) for model in models]}
def format_training_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialize a list of training runs under the ``training_runs`` key.

    Conversion preference per run: ``model_dump()``, then ``dict()`` the
    method, then the ``dict(...)`` constructor.
    """
    formatted_runs = [
        run.model_dump() if hasattr(run, 'model_dump')
        else run.dict() if hasattr(run, 'dict')
        else dict(run)
        for run in runs
    ]
    return {"training_runs": formatted_runs}
async def format_enhanced_training_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialize training runs, adding model details and S3 URLs.

    For every run: convert it to a dict, look up its model by ``model_id``
    (a failed lookup is logged and skipped), copy the model's name and version
    into ``model_type`` / ``model_version``, and add a ``*_url`` entry next to
    each non-empty dataset key.

    Returns:
        ``{"training_runs": [...]}``
    """
    # dataset key holding an S3 key -> key that receives the matching URL
    url_fields = (
        ('train', 'train_url'),
        ('validation', 'validation_url'),
        ('test', 'test_url'),
        ('s3_key', 'data_file_url'),
    )

    formatted_runs = []
    for run in runs:
        if hasattr(run, 'model_dump'):
            run_dict = run.model_dump()
        elif hasattr(run, 'dict'):
            run_dict = run.dict()
        else:
            run_dict = dict(run)

        model_info = None
        model_id = run_dict.get('model_id')
        if model_id:
            try:
                model_info = await async_model_crud.get_by_model_id(model_id)
            except Exception as e:
                print(f"[WARNING] Failed to get model info for model_id {run_dict.get('model_id')}: {e}")

        if model_info:
            run_dict['model_type'] = model_info.model_name
            run_dict['model_version'] = model_info.model_version

        datasets = run_dict.get('datasets')
        if datasets:
            for source_key, url_key in url_fields:
                if datasets.get(source_key):
                    datasets[url_key] = get_s3_url(datasets[source_key])

        formatted_runs.append(run_dict)

    return {"training_runs": formatted_runs}
async def format_universal_runs_response(runs: Union[Any, List[Any]], entity_type: str = "run") -> Union[Dict[str, Any], Dict[str, List[Dict]]]:
    """
    Format one run or a list of runs (training/validation/prediction) with
    model information and S3 URLs.

    Args:
        runs: Single run object or list of run objects.
        entity_type: "training", "validation" or "prediction" selects the
            matching ``format_*_run_schema`` transformer; any other value
            keeps the raw run dict and only adds dataset URLs.

    Returns:
        Dict: always ``{"<entity_type>_runs": [formatted_runs]}``. A single
        run is returned as a one-element list. Falsy entries in a list input
        are skipped.

    Raises:
        HTTPException: 404 when a single (non-list) run is falsy.
    """
    # dataset key holding an S3 key -> key that receives the matching URL
    url_fields = (
        ('train', 'train_url'),
        ('validation', 'validation_url'),
        ('test', 'test_url'),
        ('s3_key', 'data_file_url'),
    )

    async def format_single_run(run):
        """Helper function to format a single run; returns None for a falsy run."""
        if not run:
            return None

        if hasattr(run, 'model_dump'):
            run_dict = run.model_dump()
        elif hasattr(run, 'dict'):
            run_dict = run.dict()
        else:
            run_dict = dict(run)

        model_info = None
        if run_dict.get('model_id'):
            try:
                model_info = await async_model_crud.get_by_model_id(run_dict['model_id'])
            except Exception as e:
                print(f"[WARNING] Failed to get model info for model_id {run_dict.get('model_id')}: {e}")

        if model_info:
            # A default is required here: a bare next() on an exhausted
            # generator raises StopIteration, which inside a coroutine
            # surfaces as RuntimeError. Fall back to the raw model name when
            # no label is registered for it.
            run_dict['model_type'] = next(
                (model_name["label"] for model_name in model_names
                 if model_name["code"] == model_info.model_name),
                model_info.model_name,
            )
            run_dict['model_version'] = model_info.model_version

        # Transform based on entity type to match desired schema
        if entity_type == "training":
            return format_training_run_schema(run_dict)
        if entity_type == "validation":
            return format_validation_run_schema(run_dict)
        if entity_type == "prediction":
            return format_prediction_run_schema(run_dict)

        # Default behavior for backward compatibility: add S3 URLs in place
        datasets = run_dict.get('datasets')
        if datasets:
            for source_key, url_key in url_fields:
                if datasets.get(source_key):
                    datasets[url_key] = get_s3_url(datasets[source_key])
        return run_dict

    # Handle single run case
    if not isinstance(runs, list):
        formatted_run = await format_single_run(runs)
        if formatted_run is None:
            raise HTTPException(status_code=404, detail=f"{entity_type.capitalize()} run not found")
        # Return as a list with single item
        return {f"{entity_type}_runs": [formatted_run]}

    # Handle multiple runs case
    formatted_runs = []
    for run in runs:
        formatted_run = await format_single_run(run)
        if formatted_run:
            formatted_runs.append(formatted_run)
    return {f"{entity_type}_runs": formatted_runs}
def format_training_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Reshape a training-run dict into the API response schema.

    ``user_id`` is exposed as ``user_name``. When ``data_size`` has a 'bert'
    entry, that nested value is used as the data size. Missing training-config
    values fall back to test_size=0.2, random_state=42, stratify=True and
    class_weights=None.
    """
    datasets = run_dict.get('datasets') or {}
    training_config = run_dict.get('training_config') or {}
    data_size = run_dict.get('data_size') or {}
    if 'bert' in data_size:
        data_size = data_size['bert']
        print("[DEBUG] Adjusted data_size for bert model: ", data_size)

    config_section = {
        "test_size": training_config.get('test_size', 0.2),
        "random_state": training_config.get('random_state', 42),
        "stratify": training_config.get('stratify', True),
        "class_weights": training_config.get('class_weights',None)
    }
    # Prefer the URL for 's3_key'; fall back to the 'train' key when it is falsy.
    train_data_section = {
        "data_source": "aws_s3",
        "data_file": get_s3_url(datasets.get('s3_key')) or get_s3_url(datasets.get('train'))
    }
    datasets_section = {
        "data_size": {"total_records": data_size.get('total_records')},
        "training_config": config_section,
        "train_data": train_data_section,
    }

    return {
        "training_id": run_dict.get('training_id'),
        "user_name": run_dict.get('user_id'),  # user_id is exposed as user_name
        "model_version": run_dict.get('model_version'),
        "model_type": run_dict.get('model_type'),
        "compliance_type": run_dict.get('compliance_type'),
        "status": run_dict.get('status'),
        "created_at": run_dict.get('created_at'),
        "datasets": datasets_section,
    }
def format_validation_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Transform a validation run to match the desired API schema.

    ``accuracy_file_level`` is a whole-number percentage:
    - when ``model_name`` is "prime", it comes from the top-level "accuracy"
      entry of the classification report;
    - otherwise it is the mean of the per-label "accuracy" values found in
      the classification report.
    It stays ``None`` when no usable accuracy is present.

    ``user_id`` is exposed as ``user_name``.
    """
    datasets = run_dict.get('datasets') or {}
    data_size = run_dict.get('data_size') or {}
    metrics = run_dict.get('metrics') or {}
    classification_report = metrics.get("classification_report") or {}

    def _whole_percent(fraction):
        # Still truncates to a whole percent, but first strips binary
        # floating-point noise: 0.57 * 100 evaluates to 56.99999999999999 and
        # a bare int() would report 56 instead of 57.
        return int(round(fraction * 100, 6))

    accuracy_file_level = None
    if run_dict.get('model_name') == "prime":
        direct_accuracy = classification_report.get("accuracy") or None
        if direct_accuracy is not None and isinstance(direct_accuracy, (int, float)):
            accuracy_file_level = _whole_percent(direct_accuracy)
    else:
        accuracies = []
        for label, label_metrics in classification_report.items():
            if isinstance(label_metrics, dict):
                acc = label_metrics.get("accuracy")
                if isinstance(acc, (int, float)):
                    accuracies.append(acc)
        if accuracies:
            avg_accuracy = sum(accuracies) / len(accuracies)
            accuracy_file_level = _whole_percent(avg_accuracy)

    transformed = {
        "validation_id": run_dict.get('validation_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "model_type": run_dict.get('model_type'),
        "compliance_type": run_dict.get('compliance_type'),
        "status": run_dict.get('status'),
        "accuracy_file_level": accuracy_file_level,
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "data_size": {
                "total_records": data_size.get('total_records'),
            },
            "validate_data_input": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(datasets.get('s3_key')) or get_s3_url(datasets.get('validation'))
            },
            "validate_data_output": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(run_dict.get('results_s3_key'))  # Use results S3 key for output JSON
            }
        }
    }
    return transformed
def format_prediction_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Transform a prediction run to match the desired API schema.

    ``accuracy_file_level`` is the mean, as a whole-number percentage, of each
    prediction's 'accuracy' value (or 'probability' when 'accuracy' is absent).
    Values <= 1.0 are treated as fractions and scaled to percent. Entries with
    a missing, non-numeric or NaN value are skipped. The field stays ``None``
    when nothing usable is found.

    ``user_id`` is exposed as ``user_name``.
    """
    input_data = run_dict.get('input_data') or {}
    predictions = run_dict.get('predictions') or {}

    accuracy_file_level = None
    if predictions and isinstance(predictions, dict):
        # The actual list of predictions lives under the 'predictions' key
        predictions_list = predictions.get('predictions', [])
        if predictions_list and isinstance(predictions_list, list):
            total_accuracy = 0
            count = 0
            for pred in predictions_list:
                if not isinstance(pred, dict):
                    continue
                # Explicit None checks: `a or b` would discard a legitimate
                # accuracy of 0.0 and inflate the average.
                accuracy_value = pred.get('accuracy')
                if accuracy_value is None:
                    accuracy_value = pred.get('probability')
                if accuracy_value is None:
                    continue
                try:
                    is_fraction = accuracy_value <= 1.0
                except TypeError:
                    # Non-numeric value (e.g. a string): skip it, don't crash
                    continue
                if accuracy_value != accuracy_value:
                    # NaN would poison the sum and make int() raise
                    continue
                # If it's a probability (0-1), convert to percentage
                if is_fraction:
                    accuracy_value = accuracy_value * 100
                total_accuracy += accuracy_value
                count += 1
            if count > 0:
                # round(..., 6) strips binary floating-point noise before the
                # truncation (0.29 * 100 evaluates to 28.999999999999996).
                accuracy_file_level = int(round(total_accuracy / count, 6))

    transformed = {
        "prediction_id": run_dict.get('prediction_id'),
        "training_id": run_dict.get('training_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "compliance_type": run_dict.get('compliance_type'),
        "model_type": run_dict.get('model_type'),
        "status": run_dict.get('status'),
        "accuracy_file_level": accuracy_file_level,
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "total_records": predictions.get('total_samples') or predictions.get('total_predictions') or input_data.get('record_count'),
            "predict_data_input": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(input_data.get('s3_key'))
            },
            "predict_data_output": {
                "data_source": "aws_s3",
                # Try the predictions dict first, then the run dict
                "data_file": get_s3_url(predictions.get('results_s3_key')) or get_s3_url(run_dict.get('results_s3_key'))
            }
        }
    }
    return transformed
def format_validation_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialize a list of validation runs under the ``validation_runs`` key.

    Conversion preference per run: ``model_dump()``, then ``dict()`` the
    method, then the ``dict(...)`` constructor.
    """
    def _serialize(run):
        for method_name in ('model_dump', 'dict'):
            if hasattr(run, method_name):
                return getattr(run, method_name)()
        return dict(run)

    return {"validation_runs": [_serialize(run) for run in runs]}
def format_prediction_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialize a list of prediction runs under the ``prediction_runs`` key.

    Conversion preference per run: ``model_dump()``, then ``dict()`` the
    method, then the ``dict(...)`` constructor.
    """
    def _serialize(run):
        if hasattr(run, 'model_dump'):
            return run.model_dump()
        return run.dict() if hasattr(run, 'dict') else dict(run)

    return {"prediction_runs": list(map(_serialize, runs))}
def format_single_model_response(model: Any) -> Dict[str, Any]:
    """Serialize one model for an API response.

    Conversion preference: ``model_dump()``, then ``dict()`` the method, then
    the ``dict(...)`` constructor.

    Raises:
        HTTPException: 404 when ``model`` is falsy.
    """
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")
    for method_name in ('model_dump', 'dict'):
        if hasattr(model, method_name):
            return getattr(model, method_name)()
    return dict(model)
def generate_model_id() -> str:
    """Return a fresh random UUID (version 4) in its canonical string form."""
    new_id = uuid.uuid4()
    return str(new_id)
def generate_alert_id() -> str:
    """Human-readable alert ID: ALT-YYYYMMDD-<short_uuid>.

    The date is the current UTC date; the suffix is the first 8 hex digits of
    a random UUID.
    """
    today_utc = datetime.now(timezone.utc)
    short_uid = uuid.uuid4().hex[:8]
    return f"ALT-{today_utc:%Y%m%d}-{short_uid}"
def safe_float_conversion(value: Any, default: float = 0.0) -> float:
    """Safely convert ``value`` to float, falling back to ``default``.

    ``default`` is returned when the value is missing per ``pd.isna`` or
    cannot be converted. That includes values too large for a float, which
    raise OverflowError and previously escaped this function.
    """
    try:
        if pd.isna(value):
            return default
        return float(value)
    except (ValueError, TypeError, OverflowError):
        return default
def safe_int_conversion(value: Any, default: int = 0) -> int:
    """Safely convert ``value`` to int, falling back to ``default``.

    ``default`` is returned when the value is missing per ``pd.isna`` or
    cannot be converted. That includes infinities, for which ``int()`` raises
    OverflowError and which previously escaped this function.
    """
    try:
        if pd.isna(value):
            return default
        return int(value)
    except (ValueError, TypeError, OverflowError):
        return default
def calculate_progress_percentage(current_epoch: int, total_epochs: int) -> float:
    """Calculate training progress as a percentage clamped to [0.0, 100.0].

    Args:
        current_epoch: Number of epochs completed so far.
        total_epochs: Planned number of epochs.

    Returns:
        0.0 when ``total_epochs`` is not positive; otherwise
        ``current_epoch / total_epochs * 100`` limited to the 0-100 range. A
        negative ``current_epoch`` yields 0.0 rather than a negative value.
    """
    if total_epochs <= 0:
        return 0.0
    percentage = (current_epoch / total_epochs) * 100.0
    return max(0.0, min(100.0, percentage))
def estimate_completion_time(start_time: datetime, current_epoch: int, total_epochs: int) -> Optional[datetime]:
    """Estimate the completion time by extrapolating the average epoch duration.

    Args:
        start_time: When training started. A naive value is interpreted as
            UTC (as before); a timezone-aware value is also accepted.
        current_epoch: Number of epochs completed so far.
        total_epochs: Planned number of epochs.

    Returns:
        ``None`` when either epoch count is not positive. Otherwise the
        estimated completion time, naive UTC for a naive ``start_time`` and
        aware (in the same tzinfo) for an aware one. When ``current_epoch``
        already exceeds ``total_epochs`` the current time is returned.
    """
    if current_epoch <= 0 or total_epochs <= 0:
        return None

    # Take "now" once so both uses agree, and match start_time's awareness so
    # the subtraction below cannot raise TypeError. The naive branch is the
    # non-deprecated equivalent of datetime.utcnow().
    if start_time.tzinfo is not None:
        now = datetime.now(start_time.tzinfo)
    else:
        now = datetime.now(timezone.utc).replace(tzinfo=None)

    time_per_epoch = (now - start_time) / current_epoch
    # Never extrapolate into the past when more epochs ran than planned.
    remaining_epochs = max(total_epochs - current_epoch, 0)
    return now + time_per_epoch * remaining_epochs
def clean_model_response(model_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a cleaned copy of a model document for an API response.

    - ``_id`` is renamed to ``id`` and converted to a string.
    - ``datetime`` values become ISO-8601 strings.
    - dict items inside list values are cleaned recursively.
    All other values, including dicts that are not inside a list, are copied
    through unchanged.
    """
    cleaned_data = {}
    for key, value in model_data.items():
        if key == "_id":
            cleaned_data["id"] = str(value)
            continue
        if isinstance(value, datetime):
            cleaned_data[key] = value.isoformat()
        elif isinstance(value, list):
            # e.g. a list of runs: clean the dict entries, keep the rest as-is
            cleaned_data[key] = [
                clean_model_response(item) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            cleaned_data[key] = value
    return cleaned_data
def prepare_prediction_response(predictions: List[Dict[str, Any]],
                                transaction_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Pair every prediction with a transaction ID.

    The ID at the same index in ``transaction_ids`` is used when there is
    one; otherwise a ``transaction_<index>`` placeholder is generated.

    Returns:
        One ``{"transaction_id": ..., "predictions": ...}`` dict per
        prediction, in input order.
    """
    def _id_for(position):
        if transaction_ids and position < len(transaction_ids):
            return transaction_ids[position]
        return f"transaction_{position}"

    return [
        {"transaction_id": _id_for(position), "predictions": prediction}
        for position, prediction in enumerate(predictions)
    ]
def handle_missing_values(df: pd.DataFrame, fill_strategy: str = "unknown") -> pd.DataFrame:
    """Return a copy of ``df`` with missing values handled.

    Strategies:
        "zero": fill missing values with 0.
        "drop": drop rows containing missing values.
        "unknown" or any other value: fill missing values with "Unknown".

    The input DataFrame is not modified.
    """
    working_copy = df.copy()
    if fill_strategy == "zero":
        return working_copy.fillna(0)
    if fill_strategy == "drop":
        return working_copy.dropna()
    # "unknown" and every unrecognised strategy share the same fallback
    return working_copy.fillna("Unknown")
def validate_config_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
    """Build a sanitized training configuration from ``config``.

    Numeric settings are converted with the safe converters and clamped:
        batch_size     default 8,    range 1..64
        learning_rate  default 2e-5, range 1e-6..1e-3
        num_epochs     default 2,    range 1..20
        max_length     default 128,  range 32..512
    ``model_name`` falls back to "roberta-base" unless it is a non-blank
    string (which is stripped). ``random_state`` defaults to 42 and is not
    clamped. Only these six keys appear in the result.
    """
    def _clamp(value, lower, upper):
        return max(lower, min(upper, value))

    raw_model_name = config.get("model_name", "roberta-base")
    if isinstance(raw_model_name, str) and raw_model_name.strip():
        model_name = raw_model_name.strip()
    else:
        model_name = "roberta-base"

    return {
        "batch_size": _clamp(safe_int_conversion(config.get("batch_size", 8), 8), 1, 64),
        "learning_rate": _clamp(safe_float_conversion(config.get("learning_rate", 2e-5), 2e-5), 1e-6, 1e-3),
        "num_epochs": _clamp(safe_int_conversion(config.get("num_epochs", 2), 2), 1, 20),
        "max_length": _clamp(safe_int_conversion(config.get("max_length", 128), 128), 32, 512),
        "model_name": model_name,
        "random_state": safe_int_conversion(config.get("random_state", 42), 42),
    }
def create_error_response(error_message: str, status_code: int = 500) -> HTTPException:
    """Build (but do not raise) an ``HTTPException`` for the given message.

    Args:
        error_message: Text placed in the exception's ``detail``.
        status_code: HTTP status code; defaults to 500.
    """
    error = HTTPException(status_code=status_code, detail=error_message)
    return error
def generate_training_id() -> str:
    """Return a new training ID; same format as ``generate_operation_id``."""
    training_id = generate_operation_id()
    return training_id
def generate_run_id() -> str:
    """Return a new run ID; same format as ``generate_operation_id``."""
    run_id = generate_operation_id()
    return run_id
async def process_file_upload(file_bytes: bytes, filename: str, context: str) -> Tuple[str, str]:
    """Upload raw file bytes to S3 and write them to a temp file.

    Args:
        file_bytes: File content.
        filename: Name used to derive the S3 key.
        context: Tag placed in the log lines.

    Returns:
        ``(s3_key, file_path)``: the S3 key of the upload and the temp-file path.
    """
    upload_key = s3_key_for_upload(filename)

    # Upload first, then create the local working copy.
    await async_upload_file_to_s3(file_bytes, upload_key)
    print(f"[{context}] File uploaded to S3: {upload_key}")

    # NOTE(review): the temp file always gets a '.csv' suffix, whatever
    # ``filename`` is.
    temp_path = await async_write_temp_file(file_bytes, suffix='.csv')
    print(f"[{context}] File copied to temp for processing: {temp_path}")

    return upload_key, temp_path
async def validate_model_access(user_id: str, model_name: ModelName, version: str, level: str = None) -> Any:
    """Check that a model version belongs to ``user_id``.

    Args:
        user_id: User requesting access.
        model_name: Model the version belongs to.
        version: Version to check; "latest" (any casing) is always accepted.
        level: Level passed to the lookup; replaced by "multi_output" for
            ``ModelName.ROBERTA_COMPLIANCE``.

    Returns:
        ``True`` for "latest", otherwise the matching training run.

    Raises:
        HTTPException: 404 when no training run matches, 403 when the run
            belongs to another user.
    """
    from app.services.train import TrainingService

    training_service = TrainingService()
    print(f"[DEBUG] validate_model_access is called with user_id={user_id}, model_name={model_name}, version={version}, level={level}")

    if version.lower() == "latest":
        return True

    lookup_level = "multi_output" if model_name == ModelName.ROBERTA_COMPLIANCE else level
    # NOTE(review): ``version`` is passed through as received here, whereas
    # get_training_id_from_version converts it with int() before calling
    # get_by_version — confirm which type the service expects.
    training_run = await training_service.get_by_version(user_id, model_name, version, lookup_level)

    if not training_run:
        raise HTTPException(status_code=404, detail=f"Training run with id '{version}' not found")
    if training_run.user_id != user_id:
        raise HTTPException(status_code=403, detail=f"User '{user_id}' does not have access to version '{version}'")

    print(f"[DEBUG] User '{user_id}' has access to version '{version}'")
    return training_run
async def get_training_id_from_version(user_id: str, version: str, model_name: str, compliance_type: str, level=None) -> str:
    """
    Resolve a version ('latest' or a version number) to a training_id.

    Args:
        user_id: User identifier
        version: 'latest' (any casing) or a string holding an integer version
        model_name: Model name
        compliance_type: Compliance type; for ``ComplianceType.NON_FIRCO`` the
            lookup level is forced to "multi_output"
        level: Level used for the lookup for other compliance types

    Returns:
        str: Resolved training_id

    Raises:
        HTTPException: 400 when ``version`` is neither 'latest' nor an
            integer; 404 when no matching trained model or training run exists.
    """
    from app.services.train import TrainingService

    print(f"[LATEST_RESOLVE] Checking version: {version} for user_id: {user_id}, model_name: {model_name}, compliance_type: {compliance_type} , level: {level}")
    training_service = TrainingService()

    if version.lower() == "latest":
        if compliance_type == ComplianceType.NON_FIRCO:
            user_and_model = await async_user_and_model_crud.get_by_user_and_model(user_id, model_name, compliance_type, level="multi_output")
        else:
            user_and_model = await async_user_and_model_crud.get_by_user_and_model(user_id, model_name, compliance_type, level=level)

        if not user_and_model:
            print(f"[LATEST_RESOLVE] No trained {compliance_type} model found for user {user_id}")
            raise HTTPException(status_code=404, detail=f"No trained {compliance_type} model found for this user")

        resolved_id = user_and_model.latest_version_training_id
        print(f"[LATEST_RESOLVE] Resolved 'latest' to training_id: {resolved_id} for {compliance_type}")
        return resolved_id

    print(f"[DEBUG] trying to get the version with parameters user_id={user_id}, model_name={model_name}, version={version}, level={level}")

    # A non-numeric version is a client error: report it as a 400 instead of
    # letting int() raise an unhandled ValueError.
    try:
        version_number = int(version)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid version '{version}': expected 'latest' or an integer version number"
        ) from None

    if compliance_type == ComplianceType.NON_FIRCO:
        training_run = await training_service.get_by_version(user_id, model_name, version_number, level="multi_output")
    else:
        training_run = await training_service.get_by_version(user_id, model_name, version_number, level)

    if not training_run:
        raise HTTPException(status_code=404, detail=f"Training run with version '{version}' not found")

    print(f"[LATEST_RESOLVE] Using provided version: {version}")
    return training_run.training_id
async def get_validate_result_json_s3_keys(
    validation_id: str,
    compliance_type: ComplianceType,
    model_name: ModelName,
    classification_report: Dict[str, Any],
    training_id: str,
    user_id: str,
    filename: str,
    file_download_url: str,
    detailed_predictions: Any,
):
    """
    Write a validation run's results to a JSON file and upload it to S3.

    The S3 key is derived from validation_id up front. A ValidationResponse is
    then built, dumped to "<validation_id>_results.json" under MODEL_SAVE_DIR,
    and the file's bytes are uploaded under that key.

    Any exception raised while building, writing or uploading is printed
    (message, type, traceback) and swallowed, so the key is returned even when
    the upload did not succeed.

    Returns:
        The S3 key generated for the results JSON.
    """
    # The key is needed before the upload so it can be embedded in the payload.
    results_s3_key = s3_key_for_upload(f"{validation_id}_results.json")
    print(f"[DEBUG] Generated S3 key: {results_s3_key}")

    try:
        # model_names is a flat list of {'code', 'label'} entries.
        frontend_model_name = ""
        for entry in model_names:
            code = entry['code']
            print(f"[DEBUG] Checking code: {code} against model_name: {model_name}")
            if not (code == model_name or code == str(model_name)):
                continue
            frontend_model_name = entry['label']
            print(f"[DEBUG] Found frontend model name: {frontend_model_name} for model_name: {model_name}")
            break

        if not frontend_model_name:
            print(f"[WARNING] No frontend model name found for {model_name}, using default")
            frontend_model_name = str(model_name)

        results_url = (
            f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{results_s3_key}"
            if results_s3_key
            else None
        )

        # Response object used only as the payload of the stored JSON.
        payload = ValidationResponse(
            message="Validation completed successfully",
            metrics={"classification_report": classification_report},
            training_id=training_id,
            user_id=user_id,
            model_name=frontend_model_name,
            compliance_type=compliance_type,
            file=filename,
            dataset_download_url=file_download_url,
            results_download_url=results_url,
            detailed_predictions=detailed_predictions,
        )

        # Persist locally, then upload the exact bytes that were written.
        local_json_path = MODEL_SAVE_DIR / f"{validation_id}_results.json"
        with open(local_json_path, "w") as out_file:
            json.dump(payload.model_dump(), out_file, indent=2)
        with open(local_json_path, "rb") as in_file:
            raw_bytes = in_file.read()
        await async_upload_file_to_s3(raw_bytes, results_s3_key)
    except Exception as exc:
        print(f"[ERROR] Failed to save validation results to S3: {exc!s}")
        print(f"[ERROR] Error type: {type(exc).__name__}")
        print("[ERROR] Full traceback:")
        traceback.print_exc()

    return results_s3_key
async def get_prediction_result_json_s3_keys(
    prediction_id: str,
    prediction_results: List[Dict[str, Any]],
    file_download_url: str,
    filename: str,
    training_id: str,
    compliance_type: ComplianceType,
    model_name: ModelName,
    source_type: str,
):
    """
    Write a prediction run's results to a JSON file and upload it to S3.

    A BatchPredictionResponse is built, dumped to
    "<prediction_id>_results.json" under MODEL_SAVE_DIR, and the file's bytes
    are uploaded under the generated S3 key.

    Failures while building, writing or uploading are printed and swallowed
    (best effort), so the key is returned even when the upload did not succeed.

    Returns:
        The S3 key generated for the results JSON.
    """
    # Generate the key before the try block: the final return must always have
    # a bound value, even when something inside the try fails early.
    results_s3_key = s3_key_for_upload(f"{prediction_id}_results.json")

    try:
        result_json_path = MODEL_SAVE_DIR / f"{prediction_id}_results.json"

        frontend_model_name = ""
        print(f"[DEBUG] model_names structure: {model_names}")
        # model_names is a simple list, not organized by compliance_type
        for entry in model_names:
            code = entry['code']
            print(f"[DEBUG] Checking code: {code} against model_name: {model_name}")
            if code == model_name or code == str(model_name):
                frontend_model_name = entry['label']
                print(f"[DEBUG] Found frontend model name: {frontend_model_name} for model_name: {model_name}")
                break

        # Same fallback as the validation results writer: never store an empty
        # model name when no label matches.
        if not frontend_model_name:
            print(f"[WARNING] No frontend model name found for {model_name}, using default")
            frontend_model_name = str(model_name)

        # Prepare response for JSON storage
        temp_response = BatchPredictionResponse(
            message="Prediction completed successfully",
            predictions=prediction_results,
            training_id=training_id,
            model_name=frontend_model_name,
            compliance_type=compliance_type,
            source_type=source_type,
            file=filename,
            dataset_download_url=file_download_url,
            results_download_url=f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{results_s3_key}" if results_s3_key else None,
        )

        with open(result_json_path, "w") as f:
            json.dump(temp_response.model_dump(), f)
        with open(result_json_path, "rb") as f:
            file_content = f.read()
        await async_upload_file_to_s3(file_content, results_s3_key)
        print(f"[PREDICT] Results JSON uploaded to S3: {results_s3_key}")
    except Exception as e:
        # Best effort: log the failure with its traceback and still return the key.
        print(f"Failed to save prediction results to S3: {str(e)}")
        traceback.print_exc()

    return results_s3_key