import json
import traceback
from fastapi import HTTPException, UploadFile
from typing import Optional, List, Dict, Any, Tuple, Union
import uuid
from datetime import datetime, timezone
import pandas as pd
from app.schemas.prediction import BatchPredictionResponse
from app.schemas.validation import ValidationResponse
from app.services.auth import UserService
from app.core.enums import *
from app.core.constant import *
from app.utils.s3_utils import *
from app.services.models import async_model_crud
from app.utils.dataset_utils import convert_firco_json_to_dataframe, convert_non_firco_json_to_dataframe
import io
from app.core.enums import ModelName
from app.ml.models.firco_ensemble_model import FircoEnsembleXGBModel
from app.ml.models.firco_tfidf_model import FircoTFIDFXGBModel
from app.ml.models.non_firco_roberta_model import NonFircoRobertaModel
from app.services.user_and_model import async_user_and_model_crud

user_service = UserService()


# ---------------------------------------------------------------------------
# Private helpers shared by the formatting / validation functions below
# ---------------------------------------------------------------------------

def _to_dict(obj: Any) -> Dict[str, Any]:
    """Convert *obj* to a plain dict.

    Tries ``obj.model_dump()`` first, then ``obj.dict()``, and finally ``dict(obj)``.
    """
    if hasattr(obj, 'model_dump'):
        return obj.model_dump()
    if hasattr(obj, 'dict'):
        return obj.dict()
    return dict(obj)


def _attach_dataset_urls(datasets: Dict[str, Any]) -> None:
    """Add ``*_url`` entries (built with ``get_s3_url``) next to each present dataset key, in place."""
    for source_key, url_key in (
        ('train', 'train_url'),
        ('validation', 'validation_url'),
        ('test', 'test_url'),
        ('s3_key', 'data_file_url'),
    ):
        if datasets.get(source_key):
            datasets[url_key] = get_s3_url(datasets[source_key])


def _resolve_frontend_model_name(model_name: Any) -> str:
    """Return the ``label`` of the first ``model_names`` entry whose ``code`` matches *model_name*.

    ``model_names`` comes from one of the module's star imports. A match is either
    ``code == model_name`` or ``code == str(model_name)``. Returns "" when nothing matches.
    """
    for entry in model_names:
        code = entry['code']
        if code == model_name or code == str(model_name):
            return entry['label']
    return ""


def _validate_non_blank_string(value: Any, param_name: str) -> None:
    """Raise HTTP 400 "Valid <param_name> is required" unless *value* is a non-blank str."""
    if not value or not isinstance(value, str) or len(value.strip()) == 0:
        raise HTTPException(status_code=400, detail=f"Valid {param_name} is required")


# ---------------------------------------------------------------------------
# Public helpers
# ---------------------------------------------------------------------------

def validate_level(level: str):
    """Validate that the level is one of the allowed values ("hit", "message", "both").

    Raises:
        HTTPException: 400 for any other value.
    """
    if level not in ("hit", "message", "both"):
        raise HTTPException(status_code=400, detail="level must be 'hit', 'message', or 'both'")


def format_single_run_response(run: Any, entity_name: str) -> Dict[str, Any]:
    """Format a single run (training/validation/prediction) for API response.

    Raises:
        HTTPException: 404 "<entity_name> not found" when *run* is falsy.
    """
    if not run:
        raise HTTPException(status_code=404, detail=f"{entity_name} not found")
    return _to_dict(run)


## To be changed later (should use version instead of training)
async def get_version_from_model_id(model_id: str, async_model_crud) -> str:
    """Return the training_id of the first training run listed for *model_id*.

    Args:
        model_id: Model identifier.
        async_model_crud: CRUD object exposing ``get_by_model_id`` (shadows the module-level import).

    Raises:
        HTTPException: 404 if the model does not exist or has no training runs.
    """
    # Imported locally; presumably avoids a circular import with app.services.train.
    from app.services.train import TrainingService
    training_service = TrainingService()

    model_record = await async_model_crud.get_by_model_id(model_id)
    if not model_record:
        raise HTTPException(status_code=404, detail=f"Model with model_id '{model_id}' not found")

    training_runs = await training_service.list_by_model(model_id)
    if not training_runs:
        raise HTTPException(status_code=404, detail=f"No training runs found for model_id '{model_id}'")

    # Treated as the latest run; assumes list_by_model returns newest first — TODO confirm.
    return training_runs[0].training_id


def get_model_files_for_level(level: str, training_id: str) -> List[tuple]:
    """Generate list of (s3_filename, zip_filename) for a given level."""
    return [
        (f"xgb_model_{level}_{training_id}.pkl", f"{level}_model.pkl"),
        (f"feature_encoder_{level}_{training_id}.pkl", f"{level}_feature_encoder.pkl"),
        (f"label_encoder_{level}_{training_id}.pkl", f"{level}_label_encoder.pkl"),
    ]


def get_download_files_and_name(level: str, training_id: str, model_id: str) -> tuple:
    """Return ``(files_to_download, zip_filename)`` for *level*.

    For "both", the hit-level and message-level file lists are concatenated.
    """
    if level == "both":
        files_to_download = (get_model_files_for_level("hit", training_id)
                             + get_model_files_for_level("message", training_id))
        zip_filename = f"model_{model_id}_both_{training_id}.zip"
    else:
        files_to_download = get_model_files_for_level(level, training_id)
        zip_filename = f"model_{model_id}_{level}_{training_id}.zip"
    return files_to_download, zip_filename


# Additional helper functions for model_utils.py duplication reduction

def validate_model_name_from_string(model_name_str: str):
    """Convert a model name string to ``ModelName``.

    Raises:
        HTTPException: 400 listing the supported values when the string is not a ModelName value.
    """
    try:
        return ModelName(model_name_str)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model_name: {model_name_str}. Supported types: {[e.value for e in ModelName]}"
        ) from None


def validate_dataframe_not_empty(df, context: str = "data"):
    """Raise HTTP 400 when *df* is empty (``df.empty`` also covers zero rows)."""
    if df.empty:
        raise HTTPException(
            status_code=400,
            detail=f"The {context} is empty or contains no valid data after processing"
        )


def validate_prediction_csv_params(file: UploadFile, level: str, training_id: str):
    """Validate all CSV prediction parameters at once."""
    validate_required_param(training_id, "training_id")
    validate_level(level)
    validate_required_param(file, "file")
    validate_csv_file(file)


def validate_prediction_json_params(training_id: str, level: str):
    """Validate all JSON prediction parameters at once."""
    validate_required_param(training_id, "training_id")
    validate_level(level)


def generate_operation_id() -> str:
    """Generate a unique ID for operations: local ``YYYYmmdd_HHMMSS`` + '_' + first 12 chars of a uuid4."""
    return datetime.now().strftime("%Y%m%d_%H%M%S") + '_' + str(uuid.uuid4())[:12]


async def create_database_record(crud_instance, record_data, record_type: str):
    """Create a record via ``crud_instance.create`` and return it.

    Raises:
        HTTPException: 500 wrapping any exception from the create call (original chained as cause).
    """
    try:
        record = await crud_instance.create(record_data)
        print(f"[DB] {record_type} record created: {record.id}")
        return record
    except Exception as e:
        print(f"[DB] Error creating {record_type} record: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e


# TODO: apply logic to get ensemble
def get_model_instances(training_id=None, model_name=None, targets=None) -> Any:
    """Instantiate the model class that corresponds to *model_name*.

    Raises:
        ValueError: if *model_name* is not one of the supported ModelName members
            (previously this surfaced as an UnboundLocalError).
    """
    # Equality comparisons (rather than a dict lookup) so raw string values compare
    # the same way they did before against the enum members.
    if model_name == ModelName.FIRCO_ENSEMBLE_XGB:
        model_class = FircoEnsembleXGBModel
    elif model_name == ModelName.FIRCO_TFIDF_XGB:
        model_class = FircoTFIDFXGBModel
    elif model_name == ModelName.ROBERTA_COMPLIANCE:
        model_class = NonFircoRobertaModel
    else:
        raise ValueError(f"Model name {model_name} is not supported by get_model_instances")
    return model_class(training_id=training_id, targets=targets)


def combine_both_level_results(hit_results: Dict, message_results: Dict, operation: str) -> Dict:
    """Combine results from hit-level and message-level runs.

    For "validation", accuracy is the mean of both levels and the classification reports
    are nested per level. For "training", S3 keys are nested per level, training times are
    summed, and hit-level ``data_size`` is kept. Any other *operation* returns None.
    """
    if operation == "validation":
        combined_accuracy = (hit_results["accuracy"] + message_results["accuracy"]) / 2
        return {
            "accuracy": combined_accuracy,
            "metrics": {
                "hit_level": hit_results["classification_report"],
                "message_level": message_results["classification_report"],
            },
            "data_size": {"total_records": hit_results["total_samples"]},
            "hit_results": hit_results,
            "message_results": message_results,
        }
    elif operation == "training":
        merged_keys = {**hit_results['artifacts']['s3_keys'], **message_results['artifacts']['s3_keys']}
        print(f"[DEBUG] Combining training results for both levels resultant artifacts is {merged_keys}")
        return {
            "data_size": hit_results["data_size"],
            "artifacts": {
                "s3_keys": {
                    "hit": hit_results["artifacts"]["s3_keys"],
                    "message": message_results["artifacts"]["s3_keys"],
                }
            },
            "training_time": hit_results["training_time"] + message_results["training_time"],
        }
    return None


def validate_model_name_support(model_name, operation: str = "operation"):
    """Raise ValueError if *model_name* is not a ModelName member."""
    # Currently supported model names for this API
    supported_names = list(ModelName.__members__.values())
    if model_name not in supported_names:
        raise ValueError(
            f"Model name {model_name} not yet implemented for {operation}. "
            f"Currently supported: {[t.value for t in supported_names]}"
        )


async def process_json_to_dataframe(data, compliance_type: ComplianceType, debug_context: str = "PREDICT") -> Tuple[Any, str, str]:
    """Convert JSON data to a DataFrame, upload it as CSV to S3 and write it to a temp file.

    Returns:
        (df, temp_file_path, s3_key)

    Raises:
        HTTPException: 400 when the converted DataFrame is empty.
    """
    if compliance_type == ComplianceType.FIRCO:
        # Convert FIRCO JSON to DataFrame using dataset_utils
        df = convert_firco_json_to_dataframe(data)
        print(f"[{debug_context}] Firco JSON data converted to DataFrame. Shape: {df.shape}")
    else:
        print("[DEBUG] Non-Firco JSON data received for conversion to DataFrame")
        df = convert_non_firco_json_to_dataframe(data)
        print(f"[{debug_context}] Non-Firco JSON data converted to DataFrame. Shape: {df.shape}")

    # Check if dataframe is empty
    validate_dataframe_not_empty(df, "JSON data")

    # A seconds-resolution timestamp alone lets concurrent requests overwrite each other's
    # debug file (and the returned s3_key); generate_operation_id adds a uuid suffix.
    debug_filename = f"debug_json_data_{generate_operation_id()}.csv"

    # Convert DataFrame to CSV bytes
    csv_buffer = io.StringIO()
    df.to_csv(csv_buffer, index=False)
    csv_bytes = csv_buffer.getvalue().encode('utf-8')

    # Async upload debug CSV to S3
    s3_key = s3_key_for_upload(debug_filename)
    await async_upload_file_to_s3(csv_bytes, s3_key)

    # Async write to temp file for processing
    file_path = await async_write_temp_file(csv_bytes, suffix='.csv')

    print(f"[{debug_context}] Debug DataFrame uploaded to S3: {debug_filename}")
    print(f"[{debug_context}] DataFrame columns: {df.columns.tolist()}")
    return df, file_path, s3_key


def validate_csv_file(file: UploadFile):
    """Raise HTTP 400 unless the uploaded file name ends with ".csv" (case-insensitive).

    A missing filename is treated as invalid rather than raising AttributeError.
    """
    filename = file.filename or ""
    if not filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")


def validate_required_param(param_value: Any, param_name: str):
    """Raise HTTP 400 "<param_name> is required" when *param_value* is falsy."""
    if not param_value:
        raise HTTPException(status_code=400, detail=f"{param_name} is required")


def validate_user_id(user_id: str):
    """Validate that user_id is a non-blank string (HTTP 400 otherwise)."""
    _validate_non_blank_string(user_id, "user_id")


def validate_version(version: str):
    """Validate that version is a non-blank string (HTTP 400 otherwise)."""
    _validate_non_blank_string(version, "version")


def validate_model_id(model_id: str):
    """Validate that model_id is a non-blank string (HTTP 400 otherwise)."""
    _validate_non_blank_string(model_id, "model_id")


def validate_required_columns(df: pd.DataFrame, required_columns: List[str], context: str = "dataset"):
    """Raise HTTP 400 listing any *required_columns* missing from *df*."""
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required columns in {context}: {missing_columns}"
        )


def format_model_list_response(models: List[Any]) -> Dict[str, List[Dict]]:
    """Format a list of models for API response."""
    return {"models": [_to_dict(model) for model in models]}


def format_training_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Format a list of training runs for API response."""
    return {"training_runs": [_to_dict(run) for run in runs]}


async def format_enhanced_training_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Format training runs for API response, adding model info and dataset S3 URLs.

    For each run with a ``model_id``, the model record is fetched and ``model_type`` /
    ``model_version`` are added. Lookup failures are logged and skipped.
    """
    formatted_runs = []
    for run in runs:
        run_dict = _to_dict(run)

        # Get model information (best-effort)
        model_info = None
        if run_dict.get('model_id'):
            try:
                model_info = await async_model_crud.get_by_model_id(run_dict['model_id'])
            except Exception as e:
                print(f"[WARNING] Failed to get model info for model_id {run_dict.get('model_id')}: {e}")

        if model_info:
            run_dict['model_type'] = model_info.model_name
            run_dict['model_version'] = model_info.model_version

        # Enhance datasets with S3 URLs
        if run_dict.get('datasets'):
            _attach_dataset_urls(run_dict['datasets'])

        formatted_runs.append(run_dict)
    return {"training_runs": formatted_runs}


async def format_universal_runs_response(runs: Union[Any, List[Any]], entity_type: str = "run") -> Union[Dict[str, Any], Dict[str, List[Dict]]]:
    """
    Format one run or a list of runs (training/validation/prediction), adding model
    information and S3 URLs.

    Args:
        runs: A single run object or a list of run objects.
        entity_type: "training", "validation", "prediction", or "run". Selects the output
            schema and the response key ``"{entity_type}_runs"``.

    Returns:
        Dict: ``{"{entity_type}_runs": [formatted_runs]}``. A single run is also returned
        wrapped in a one-element list. Falsy items in a list are dropped.

    Raises:
        HTTPException: 404 when a single (non-list) run is falsy.
    """
    async def format_single_run(run):
        """Format one run; returns None for a falsy run."""
        if not run:
            return None

        run_dict = _to_dict(run)

        model_info = None
        if run_dict.get('model_id'):
            try:
                model_info = await async_model_crud.get_by_model_id(run_dict['model_id'])
            except Exception as e:
                print(f"[WARNING] Failed to get model info for model_id {run_dict.get('model_id')}: {e}")

        if model_info:
            # A bare next() would raise StopIteration for an unknown code, which Python turns
            # into RuntimeError inside a coroutine; fall back to the raw model name instead.
            label = next(
                (entry["label"] for entry in model_names if entry["code"] == model_info.model_name),
                None,
            )
            run_dict['model_type'] = label if label is not None else model_info.model_name
            run_dict['model_version'] = model_info.model_version

        # Transform based on entity type to match desired schema
        if entity_type == "training":
            return format_training_run_schema(run_dict)
        elif entity_type == "validation":
            return format_validation_run_schema(run_dict)
        elif entity_type == "prediction":
            return format_prediction_run_schema(run_dict)

        # Default behavior for backward compatibility
        if run_dict.get('datasets'):
            _attach_dataset_urls(run_dict['datasets'])
        return run_dict

    # Handle single run case
    if not isinstance(runs, list):
        formatted_run = await format_single_run(runs)
        if formatted_run is None:
            raise HTTPException(status_code=404, detail=f"{entity_type.capitalize()} run not found")
        # Return as a list with single item
        return {f"{entity_type}_runs": [formatted_run]}

    # Handle multiple runs case
    formatted_runs = []
    for run in runs:
        formatted_run = await format_single_run(run)
        if formatted_run:
            formatted_runs.append(formatted_run)
    return {f"{entity_type}_runs": formatted_runs}


def format_training_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Transform a training run dict to the API training-run schema.

    When ``data_size`` contains a ``'bert'`` entry, that nested dict is used instead.
    Missing training_config values default to test_size=0.2, random_state=42,
    stratify=True, class_weights=None.
    """
    datasets = run_dict.get('datasets') or {}
    data_size = run_dict.get('data_size') or {}
    training_config = run_dict.get('training_config') or {}

    if 'bert' in data_size:
        data_size = data_size['bert']
        print("[DEBUG] Adjusted data_size for bert model: ", data_size)

    return {
        "training_id": run_dict.get('training_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "model_type": run_dict.get('model_type'),
        "compliance_type": run_dict.get('compliance_type'),
        "status": run_dict.get('status'),
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "data_size": {
                "total_records": data_size.get('total_records')
            },
            "training_config": {
                "test_size": training_config.get('test_size', 0.2),
                "random_state": training_config.get('random_state', 42),
                "stratify": training_config.get('stratify', True),
                "class_weights": training_config.get('class_weights', None),
            },
            "train_data": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(datasets.get('s3_key')) or get_s3_url(datasets.get('train')),
            },
        },
    }


def format_validation_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Transform a validation run dict to the API validation-run schema.

    ``accuracy_file_level`` is an int percentage (truncated). For runs whose
    ``model_name`` is "prime", the top-level ``classification_report["accuracy"]`` is
    used when numeric. Otherwise it is the mean of the numeric per-label ``"accuracy"``
    entries, or None when there are none.
    """
    datasets = run_dict.get('datasets') or {}
    data_size = run_dict.get('data_size') or {}
    metrics = run_dict.get('metrics') or {}
    classification_report = metrics.get("classification_report") or {}

    accuracy_file_level = None
    # No `or None` here: a legitimate 0.0 accuracy must not be discarded.
    direct_accuracy = classification_report.get("accuracy") if run_dict.get('model_name') == "prime" else None
    if isinstance(direct_accuracy, (int, float)):
        accuracy_file_level = int(direct_accuracy * 100)
    else:
        accuracies = [
            label_metrics["accuracy"]
            for label_metrics in classification_report.values()
            if isinstance(label_metrics, dict) and isinstance(label_metrics.get("accuracy"), (int, float))
        ]
        if accuracies:
            accuracy_file_level = int(sum(accuracies) / len(accuracies) * 100)

    return {
        "validation_id": run_dict.get('validation_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "model_type": run_dict.get('model_type'),
        "compliance_type": run_dict.get('compliance_type'),
        "status": run_dict.get('status'),
        "accuracy_file_level": accuracy_file_level,
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "data_size": {
                "total_records": data_size.get('total_records'),
            },
            "validate_data_input": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(datasets.get('s3_key')) or get_s3_url(datasets.get('validation')),
            },
            "validate_data_output": {
                "data_source": "aws_s3",
                # Use results S3 key for output JSON
                "data_file": get_s3_url(run_dict.get('results_s3_key')),
            },
        },
    }


def format_prediction_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Transform a prediction run dict to the API prediction-run schema.

    ``accuracy_file_level`` is the truncated mean over ``predictions["predictions"]`` of
    each item's ``accuracy`` (or ``probability`` when accuracy is absent). Values <= 1.0
    are scaled by 100. Non-numeric values are skipped.
    """
    input_data = run_dict.get('input_data') or {}
    predictions = run_dict.get('predictions') or {}

    accuracy_file_level = None
    if isinstance(predictions, dict):
        predictions_list = predictions.get('predictions', [])
        if predictions_list and isinstance(predictions_list, list):
            total_accuracy = 0
            count = 0
            for pred in predictions_list:
                if not isinstance(pred, dict):
                    continue
                # Prefer 'accuracy' even when it is 0; only fall back when it is missing.
                accuracy_value = pred.get('accuracy')
                if accuracy_value is None:
                    accuracy_value = pred.get('probability')
                if isinstance(accuracy_value, bool) or not isinstance(accuracy_value, (int, float)):
                    continue
                # If it's a probability (0-1), convert to percentage
                if accuracy_value <= 1.0:
                    accuracy_value = accuracy_value * 100
                total_accuracy += accuracy_value
                count += 1
            if count > 0:
                accuracy_file_level = int(total_accuracy / count)

    return {
        "prediction_id": run_dict.get('prediction_id'),
        "training_id": run_dict.get('training_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "compliance_type": run_dict.get('compliance_type'),
        "model_type": run_dict.get('model_type'),
        "status": run_dict.get('status'),
        "accuracy_file_level": accuracy_file_level,
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "total_records": (predictions.get('total_samples')
                              or predictions.get('total_predictions')
                              or input_data.get('record_count')),
            "predict_data_input": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(input_data.get('s3_key')),
            },
            "predict_data_output": {
                "data_source": "aws_s3",
                # Try predictions dict first, then run dict
                "data_file": get_s3_url(predictions.get('results_s3_key')) or get_s3_url(run_dict.get('results_s3_key')),
            },
        },
    }


def format_validation_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Format a list of validation runs for API response."""
    return {"validation_runs": [_to_dict(run) for run in runs]}


def format_prediction_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Format a list of prediction runs for API response (single definition; a duplicate was removed)."""
    return {"prediction_runs": [_to_dict(run) for run in runs]}


def format_single_model_response(model: Any) -> Dict[str, Any]:
    """Format a single model for API response; HTTP 404 when *model* is falsy."""
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")
    return _to_dict(model)


def generate_model_id() -> str:
    """Generate a unique model ID (uuid4 string)."""
    return str(uuid.uuid4())


def generate_alert_id() -> str:
    """Human-readable alert ID: ``ALT-YYYYMMDD-<8 hex chars>`` (UTC date)."""
    date_part = datetime.now(timezone.utc).strftime("%Y%m%d")
    uid_part = uuid.uuid4().hex[:8]
    return f"ALT-{date_part}-{uid_part}"


def safe_float_conversion(value: Any, default: float = 0.0) -> float:
    """Convert *value* to float; return *default* for NA values or failed conversions."""
    try:
        if pd.isna(value):
            return default
        return float(value)
    except (ValueError, TypeError):
        return default


def safe_int_conversion(value: Any, default: int = 0) -> int:
    """Convert *value* to int; return *default* for NA values or failed conversions."""
    try:
        if pd.isna(value):
            return default
        return int(value)
    except (ValueError, TypeError):
        return default


def calculate_progress_percentage(current_epoch: int, total_epochs: int) -> float:
    """Return training progress as a percentage capped at 100 (0.0 when total_epochs <= 0)."""
    if total_epochs <= 0:
        return 0.0
    return min(100.0, (current_epoch / total_epochs) * 100.0)


def estimate_completion_time(start_time: datetime, current_epoch: int, total_epochs: int) -> Optional[datetime]:
    """Estimate completion time by extrapolating the average epoch duration so far.

    A naive *start_time* is compared against naive UTC "now" and a naive UTC datetime is
    returned (same as the former ``datetime.utcnow()`` behavior). A timezone-aware
    *start_time* is compared against aware UTC and an aware datetime is returned
    (previously a TypeError).

    Returns None when current_epoch or total_epochs is not positive.
    """
    if current_epoch <= 0 or total_epochs <= 0:
        return None

    now = datetime.now(timezone.utc)
    if start_time.utcoffset() is None:
        now = now.replace(tzinfo=None)

    time_per_epoch = (now - start_time) / current_epoch
    remaining_epochs = total_epochs - current_epoch
    return now + time_per_epoch * remaining_epochs


def clean_model_response(model_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *model_data* with "_id" renamed to "id" (as str) and datetimes as ISO strings.

    Dicts inside top-level lists (e.g. lists of runs) are cleaned recursively.
    Nested plain dict values are copied unchanged.
    """
    cleaned_data = {}
    for key, value in model_data.items():
        if key == "_id":
            cleaned_data["id"] = str(value)
        elif isinstance(value, datetime):
            cleaned_data[key] = value.isoformat()
        elif isinstance(value, list):
            # Handle list of runs
            cleaned_data[key] = [
                clean_model_response(item) if isinstance(item, dict) else item
                for item in value
            ]
        else:
            cleaned_data[key] = value
    return cleaned_data


def prepare_prediction_response(predictions: List[Dict[str, Any]], transaction_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Pair each prediction with a transaction id.

    Uses ``transaction_ids[i]`` when available, otherwise ``"transaction_{i}"``.
    """
    formatted_predictions = []
    for i, pred in enumerate(predictions):
        transaction_id = transaction_ids[i] if transaction_ids and i < len(transaction_ids) else f"transaction_{i}"
        formatted_predictions.append({
            "transaction_id": transaction_id,
            "predictions": pred,
        })
    return formatted_predictions


def handle_missing_values(df: pd.DataFrame, fill_strategy: str = "unknown") -> pd.DataFrame:
    """Return a copy of *df* with missing values handled.

    Strategies: "unknown" fills with "Unknown", "zero" fills with 0, and "drop" drops rows
    with NA. Any other value falls back to "Unknown".
    """
    df_copy = df.copy()
    if fill_strategy == "zero":
        return df_copy.fillna(0)
    if fill_strategy == "drop":
        return df_copy.dropna()
    return df_copy.fillna("Unknown")


def validate_config_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
    """Validate and clamp training configuration parameters.

    Ranges / defaults: batch_size 1-64 (8), learning_rate 1e-6-1e-3 (2e-5),
    num_epochs 1-20 (2), max_length 32-512 (128), model_name non-blank str
    ("roberta-base"), random_state int (42).
    """
    validated_config = {}

    batch_size = config.get("batch_size", 8)
    validated_config["batch_size"] = max(1, min(64, safe_int_conversion(batch_size, 8)))

    learning_rate = config.get("learning_rate", 2e-5)
    validated_config["learning_rate"] = max(1e-6, min(1e-3, safe_float_conversion(learning_rate, 2e-5)))

    num_epochs = config.get("num_epochs", 2)
    validated_config["num_epochs"] = max(1, min(20, safe_int_conversion(num_epochs, 2)))

    max_length = config.get("max_length", 128)
    validated_config["max_length"] = max(32, min(512, safe_int_conversion(max_length, 128)))

    model_name = config.get("model_name", "roberta-base")
    if not isinstance(model_name, str) or not model_name.strip():
        validated_config["model_name"] = "roberta-base"
    else:
        validated_config["model_name"] = model_name.strip()

    random_state = config.get("random_state", 42)
    validated_config["random_state"] = safe_int_conversion(random_state, 42)

    return validated_config


def create_error_response(error_message: str, status_code: int = 500) -> HTTPException:
    """Create (but do not raise) a standardized HTTPException."""
    return HTTPException(status_code=status_code, detail=error_message)


def generate_training_id() -> str:
    """Generate a unique training ID."""
    return generate_operation_id()


def generate_run_id() -> str:
    """Generate a unique run ID."""
    return generate_operation_id()


async def process_file_upload(file_bytes: bytes, filename: str, context: str) -> Tuple[str, str]:
    """Upload *file_bytes* to S3 and write them to a temp ``.csv`` file.

    Returns:
        (s3_key, temp_file_path)
    """
    s3_key = s3_key_for_upload(filename)

    # Async upload to S3
    await async_upload_file_to_s3(file_bytes, s3_key)
    print(f"[{context}] File uploaded to S3: {s3_key}")

    # Async write to temp file
    file_path = await async_write_temp_file(file_bytes, suffix='.csv')
    print(f"[{context}] File copied to temp for processing: {file_path}")

    return s3_key, file_path


async def validate_model_access(user_id: str, model_name: ModelName, version: str, level: str = None) -> Any:
    """Validate that *version* of *model_name* belongs to *user_id*.

    Returns:
        True when version is "latest" (no check performed); otherwise the training run.

    Raises:
        HTTPException: 404 if no training run matches; 403 if it belongs to another user.
    """
    from app.services.train import TrainingService
    training_service = TrainingService()
    print(f"[DEBUG] validate_model_access is called with user_id={user_id}, model_name={model_name}, version={version}, level={level}")

    if version.lower() == "latest":
        return True

    # NOTE(review): version is passed as a string here, while get_training_id_from_version
    # passes int(version) to the same get_by_version — confirm which type it expects.
    lookup_level = "multi_output" if model_name == ModelName.ROBERTA_COMPLIANCE else level
    training_run = await training_service.get_by_version(user_id, model_name, version, lookup_level)

    if not training_run:
        raise HTTPException(status_code=404, detail=f"Training run with id '{version}' not found")
    if training_run.user_id != user_id:
        raise HTTPException(status_code=403, detail=f"User '{user_id}' does not have access to version '{version}'")

    print(f"[DEBUG] User '{user_id}' has access to version '{version}'")
    return training_run


async def get_training_id_from_version(user_id: str, version: str, model_name: str, compliance_type: str, level=None) -> str:
    """
    Resolve a version ("latest" or an integer string) to a training_id.

    Args:
        user_id: User identifier.
        version: "latest" (case-insensitive) or an integer version number as a string.
        model_name: Model name.
        compliance_type: Compliance type; NON_FIRCO lookups always use level "multi_output".
        level: Level used for non-NON_FIRCO lookups.

    Returns:
        str: Resolved training_id.

    Raises:
        HTTPException: 400 if version is neither "latest" nor an integer;
            404 if no matching trained model / training run exists.
    """
    from app.services.train import TrainingService
    print(f"[LATEST_RESOLVE] Checking version: {version} for user_id: {user_id}, model_name: {model_name}, compliance_type: {compliance_type} , level: {level}")
    training_service = TrainingService()
    lookup_level = "multi_output" if compliance_type == ComplianceType.NON_FIRCO else level

    if version.lower() == "latest":
        user_and_model = await async_user_and_model_crud.get_by_user_and_model(
            user_id, model_name, compliance_type, level=lookup_level
        )
        if user_and_model:
            resolved_id = user_and_model.latest_version_training_id
            print(f"[LATEST_RESOLVE] Resolved 'latest' to training_id: {resolved_id} for {compliance_type}")
            return resolved_id
        print(f"[LATEST_RESOLVE] No trained {compliance_type} model found for user {user_id}")
        raise HTTPException(status_code=404, detail=f"No trained {compliance_type} model found for this user")

    print(f"[DEBUG] trying to get the version with parameters user_id={user_id}, model_name={model_name}, version={version}, level={level}")
    try:
        version_number = int(version)
    except (TypeError, ValueError):
        # Previously an unhandled ValueError (HTTP 500) for a malformed client value.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid version '{version}': expected 'latest' or an integer"
        ) from None

    if compliance_type == ComplianceType.NON_FIRCO:
        training_run = await training_service.get_by_version(user_id, model_name, version_number, level="multi_output")
    else:
        training_run = await training_service.get_by_version(user_id, model_name, version_number, level)

    if not training_run:
        raise HTTPException(status_code=404, detail=f"Training run with version '{version}' not found")
    print(f"[LATEST_RESOLVE] Using provided version: {version}")
    return training_run.training_id


async def get_validate_result_json_s3_keys(
    validation_id: str,
    compliance_type: ComplianceType,
    model_name: ModelName,
    classification_report: Dict[str, Any],
    training_id: str,
    user_id: str,
    filename: str,
    file_download_url: str,
    detailed_predictions: Any,
):
    """Save validation results as JSON under MODEL_SAVE_DIR and upload them to S3.

    Failures while building or saving the JSON are logged with a traceback and swallowed
    (best-effort). The S3 key is returned either way.

    Returns:
        The S3 key the results JSON was (or would have been) uploaded to.
    """
    results_s3_key = s3_key_for_upload(f"{validation_id}_results.json")
    print(f"[DEBUG] Generated S3 key: {results_s3_key}")

    try:
        frontend_model_name = _resolve_frontend_model_name(model_name)
        if frontend_model_name:
            print(f"[DEBUG] Found frontend model name: {frontend_model_name} for model_name: {model_name}")
        else:
            print(f"[WARNING] No frontend model name found for {model_name}, using default")
            frontend_model_name = str(model_name)

        temp_response = ValidationResponse(
            message="Validation completed successfully",
            metrics={"classification_report": classification_report},
            training_id=training_id,
            user_id=user_id,
            model_name=frontend_model_name,
            compliance_type=compliance_type,
            file=filename,
            dataset_download_url=file_download_url,
            results_download_url=f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{results_s3_key}" if results_s3_key else None,
            detailed_predictions=detailed_predictions,
        )

        # Serialize once; write the local copy and upload the same bytes.
        payload = json.dumps(temp_response.model_dump(), indent=2)
        result_json_path = MODEL_SAVE_DIR / f"{validation_id}_results.json"
        with open(result_json_path, "w", encoding="utf-8") as f:
            f.write(payload)
        await async_upload_file_to_s3(payload.encode("utf-8"), results_s3_key)
    except Exception as e:
        print(f"[ERROR] Failed to save validation results to S3: {str(e)}")
        print(f"[ERROR] Error type: {type(e).__name__}")
        print(f"[ERROR] Full traceback:")
        traceback.print_exc()

    return results_s3_key


async def get_prediction_result_json_s3_keys(
    prediction_id: str,
    prediction_results: List[Dict[str, Any]],
    file_download_url: str,
    filename: str,
    training_id: str,
    compliance_type: ComplianceType,
    model_name: ModelName,
    source_type: str,
):
    """Save prediction results as JSON under MODEL_SAVE_DIR and upload them to S3.

    The S3 key is computed before the best-effort section. Previously an early failure
    left it unbound and the ``return`` raised UnboundLocalError. Failures while building
    or saving the JSON are logged with a traceback and swallowed.

    Returns:
        The S3 key the results JSON was (or would have been) uploaded to.
    """
    result_json_path = MODEL_SAVE_DIR / f"{prediction_id}_results.json"
    results_s3_key = s3_key_for_upload(f"{prediction_id}_results.json")

    try:
        # Same fallback as the validation variant, instead of an empty model name.
        frontend_model_name = _resolve_frontend_model_name(model_name) or str(model_name)
        print(f"[DEBUG] Resolved frontend model name: {frontend_model_name} for model_name: {model_name}")

        temp_response = BatchPredictionResponse(
            message="Prediction completed successfully",
            predictions=prediction_results,
            training_id=training_id,
            model_name=frontend_model_name,
            compliance_type=compliance_type,
            source_type=source_type,
            file=filename,
            dataset_download_url=file_download_url,
            results_download_url=f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{results_s3_key}" if results_s3_key else None,
        )

        payload = json.dumps(temp_response.model_dump())
        with open(result_json_path, "w", encoding="utf-8") as f:
            f.write(payload)
        await async_upload_file_to_s3(payload.encode("utf-8"), results_s3_key)
        print(f"[PREDICT] Results JSON uploaded to S3: {results_s3_key}")
    except Exception as e:
        print(f"Failed to save prediction results to S3: {str(e)}")
        traceback.print_exc()

    return results_s3_key