# prediqai / app/utils/utils.py
# Hugging Face page header: ganesh-vilje — commit f8f02c0 ("Deploy to Hugging Face Main")
import json
import traceback
from fastapi import HTTPException, UploadFile
from typing import Optional, List, Dict, Any, Tuple, Union
import uuid
from datetime import datetime, timezone
import pandas as pd
from app.schemas.prediction import BatchPredictionResponse
from app.schemas.validation import ValidationResponse
from app.services.auth import UserService
from app.core.enums import *
from app.core.constant import *
from app.utils.s3_utils import *
from app.services.models import async_model_crud
from app.utils.dataset_utils import convert_firco_json_to_dataframe, convert_non_firco_json_to_dataframe
import io
from app.core.enums import ModelName
from app.ml.models.firco_ensemble_model import FircoEnsembleXGBModel
from app.ml.models.firco_tfidf_model import FircoTFIDFXGBModel
from app.ml.models.non_firco_roberta_model import NonFircoRobertaModel
from app.services.user_and_model import async_user_and_model_crud
# Module-level UserService instance, created once at import time.
# NOTE(review): it is not referenced in the visible part of this module —
# confirm whether other modules import it from here before removing it.
user_service = UserService()
def validate_level(level: str):
    """Ensure *level* names a supported prediction granularity.

    Args:
        level: Must be one of ``"hit"``, ``"message"`` or ``"both"``.

    Raises:
        HTTPException: 400 for any other value.
    """
    allowed_levels = ("hit", "message", "both")
    if level in allowed_levels:
        return
    raise HTTPException(status_code=400, detail="level must be 'hit', 'message', or 'both'")
def format_prediction_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Format a list of prediction runs for API response.

    Each run is converted with ``model_dump()`` when available, else
    ``dict()``, else ``dict(run)``. This is the same fallback chain the other
    ``format_*_runs_response`` helpers in this module use.

    NOTE: this module defines ``format_prediction_runs_response`` a second
    time further down. At import time that later definition replaces this one.

    Args:
        runs: Prediction run objects (or mapping-like objects).

    Returns:
        ``{"prediction_runs": [<run as dict>, ...]}``.
    """
    def _to_dict(run: Any) -> Dict:
        # Convert one run using the first serialisation method it supports.
        if hasattr(run, "model_dump"):
            return run.model_dump()
        if hasattr(run, "dict"):
            return run.dict()
        return dict(run)

    return {"prediction_runs": [_to_dict(run) for run in runs]}
def format_single_run_response(run: Any, entity_name: str) -> Dict[str, Any]:
    """Serialise one training/validation/prediction run for an API response.

    Args:
        run: The run object. Any falsy value is treated as "not found".
        entity_name: Human-readable entity name used in the 404 message.

    Returns:
        ``run.model_dump()``.

    Raises:
        HTTPException: 404 when *run* is falsy.
    """
    if run:
        return run.model_dump()
    raise HTTPException(status_code=404, detail=f"{entity_name} not found")
# TODO: To be changed later (should use version instead of training)
async def get_version_from_model_id(model_id: str, async_model_crud) -> str:
    """Resolve *model_id* to the training_id of its first listed training run.

    Despite its name, this helper returns a *training_id*, not a version.

    Args:
        model_id: Identifier of the model record.
        async_model_crud: CRUD object exposing an awaitable
            ``get_by_model_id``. Note that this parameter shadows the
            module-level ``async_model_crud`` import.

    Returns:
        The ``training_id`` of ``training_runs[0]`` as returned by
        ``TrainingService.list_by_model``.
        NOTE(review): treated as the latest run, which assumes
        ``list_by_model`` returns newest first — confirm.

    Raises:
        HTTPException: 404 if the model or its training runs cannot be found.
    """
    # Imported here rather than at module level, presumably to avoid a
    # circular import with app.services.train.
    from app.services.train import TrainingService
    training_service = TrainingService()
    model_record = await async_model_crud.get_by_model_id(model_id)
    if not model_record:
        raise HTTPException(status_code=404, detail=f"Model with model_id '{model_id}' not found")
    training_runs = await training_service.list_by_model(model_id)
    if not training_runs:
        raise HTTPException(status_code=404, detail=f"No training runs found for model_id '{model_id}'")
    return training_runs[0].training_id  # Latest training run
def get_model_files_for_level(level: str, training_id: str) -> List[tuple]:
    """List the ``(s3_filename, zip_filename)`` pairs for one model level.

    The pairs cover the XGBoost model, its feature encoder and its label
    encoder, in that order.
    """
    # (S3 object prefix, name suffix used inside the zip archive)
    artifacts = (
        ("xgb_model", "model"),
        ("feature_encoder", "feature_encoder"),
        ("label_encoder", "label_encoder"),
    )
    return [
        (f"{s3_prefix}_{level}_{training_id}.pkl", f"{level}_{zip_suffix}.pkl")
        for s3_prefix, zip_suffix in artifacts
    ]
def get_download_files_and_name(level: str, training_id: str, model_id: str) -> tuple:
    """Return the artifact list and archive name for a model download.

    Args:
        level: ``"hit"``, ``"message"`` or ``"both"``. ``"both"`` bundles the
            hit-level artifacts followed by the message-level ones.
        training_id: Training run whose artifacts are downloaded.
        model_id: Model identifier embedded in the zip filename.

    Returns:
        ``(files_to_download, zip_filename)``, where ``files_to_download`` is a
        list of ``(s3_filename, zip_filename)`` pairs.
    """
    levels = ("hit", "message") if level == "both" else (level,)
    files_to_download = []
    for current_level in levels:
        files_to_download.extend(get_model_files_for_level(current_level, training_id))
    zip_filename = f"model_{model_id}_{level}_{training_id}.zip"
    return files_to_download, zip_filename
# Additional helper functions for model_utils.py duplication reduction
def validate_model_name_from_string(model_name_str: str):
    """Convert a model-name string into its ``ModelName`` enum member.

    Args:
        model_name_str: Value expected to match one of ``ModelName``'s values.

    Returns:
        The matching ``ModelName`` member.

    Raises:
        HTTPException: 400 listing the supported values when there is no
            match. The underlying ValueError is chained as ``__cause__``.
    """
    try:
        return ModelName(model_name_str)
    except ValueError as err:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model_name: {model_name_str}. Supported types: {[e.value for e in ModelName]}"
        ) from err
def validate_dataframe_not_empty(df, context: str = "data"):
    """Raise a 400 error when *df* holds no data.

    ``DataFrame.empty`` is True whenever either axis has length zero, which
    already covers the zero-row case.

    Args:
        df: DataFrame to check.
        context: Description used in the error message.
    """
    if not df.empty:
        return
    raise HTTPException(
        status_code=400,
        detail=f"The {context} is empty or contains no valid data after processing"
    )
def validate_prediction_csv_params(file: UploadFile, level: str, training_id: str):
    """Run every check required before a CSV prediction request is processed.

    The checks run in this order and stop at the first failure:
    training_id present → level allowed → file present → file is a CSV.
    The presence check on *file* must come before the extension check,
    because the extension check reads ``file.filename``.
    """
    validate_required_param(training_id, "training_id")
    validate_level(level)
    # Presence first, then the filename-based CSV check.
    validate_required_param(file, "file")
    validate_csv_file(file)
def validate_prediction_json_params(training_id: str, level: str):
    """Run the checks required before a JSON prediction request is processed.

    Checks, in order: *training_id* is present, then *level* is allowed.
    """
    validate_required_param(training_id, "training_id")
    validate_level(level)
def generate_operation_id() -> str:
    """Build a unique operation ID for training, validation or prediction runs.

    Format: ``YYYYmmdd_HHMMSS_<first 12 characters of a uuid4>``. The
    timestamp uses local (naive) time.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    suffix = str(uuid.uuid4())[:12]
    return f"{stamp}_{suffix}"
async def create_database_record(crud_instance, record_data, record_type: str):
    """Create a record through *crud_instance*, mapping failures to HTTP 500.

    Args:
        crud_instance: Object exposing an awaitable ``create(record_data)``.
        record_data: Payload passed straight to ``create``.
        record_type: Label used in log and error messages.

    Returns:
        Whatever ``crud_instance.create`` returns.

    Raises:
        HTTPException: 500 when ``create`` raises. The original error is
            chained as ``__cause__``.
    """
    # Keep only the database call inside the try. Logging the new record's id
    # (which may be absent) must never turn a successful insert into a 500.
    try:
        record = await crud_instance.create(record_data)
    except Exception as e:
        print(f"[DB] Error creating {record_type} record: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    print(f"[DB] {record_type} record created: {getattr(record, 'id', None)}")
    return record
# TODO: apply logic to get ensemble
def get_model_instances(training_id=None, model_name=None, targets=None) -> Any:
    """Instantiate the model wrapper class that matches *model_name*.

    Args:
        training_id: Forwarded to the model constructor.
        model_name: ``ModelName`` member selecting the implementation.
        targets: Forwarded to the model constructor.

    Returns:
        An instance of ``FircoEnsembleXGBModel``, ``FircoTFIDFXGBModel`` or
        ``NonFircoRobertaModel``.

    Raises:
        ValueError: if *model_name* is none of the supported names. Before
            this fix, an unknown name surfaced as an UnboundLocalError.
    """
    # Compare with == (not a dict lookup): Enum members hash by member name,
    # so a dict would miss plain-string inputs that compare equal to a member.
    if model_name == ModelName.FIRCO_ENSEMBLE_XGB:
        model_class = FircoEnsembleXGBModel
    elif model_name == ModelName.FIRCO_TFIDF_XGB:
        model_class = FircoTFIDFXGBModel
    elif model_name == ModelName.ROBERTA_COMPLIANCE:
        model_class = NonFircoRobertaModel
    else:
        raise ValueError(f"Unsupported model_name for get_model_instances: {model_name!r}")
    return model_class(training_id=training_id, targets=targets)
def combine_both_level_results(hit_results: Dict, message_results: Dict, operation: str) -> Dict:
    """Merge hit-level and message-level results into a single payload.

    Args:
        hit_results: Result dict produced for the hit level.
        message_results: Result dict produced for the message level.
        operation: ``"validation"`` or ``"training"``.

    Returns:
        For ``"validation"``: the mean of the two accuracies, both
        classification reports, the hit-level sample count and both raw
        result dicts.
        For ``"training"``: the hit-level data size, both sets of S3 keys and
        the summed training time.
        ``None`` for any other *operation*.
    """
    if operation == "validation":
        mean_accuracy = (hit_results["accuracy"] + message_results["accuracy"]) / 2
        per_level_reports = {
            "hit_level": hit_results["classification_report"],
            "message_level": message_results["classification_report"],
        }
        return {
            "accuracy": mean_accuracy,
            "metrics": per_level_reports,
            "data_size": {"total_records": hit_results["total_samples"]},
            "hit_results": hit_results,
            "message_results": message_results,
        }

    if operation == "training":
        hit_keys = hit_results["artifacts"]["s3_keys"]
        message_keys = message_results["artifacts"]["s3_keys"]
        merged_keys = {**hit_keys, **message_keys}
        print(f"[DEBUG] Combining training results for both levels resultant artifacts is {merged_keys}")
        total_training_time = hit_results["training_time"] + message_results["training_time"]
        return {
            "data_size": hit_results["data_size"],
            "artifacts": {"s3_keys": {"hit": hit_keys, "message": message_keys}},
            "training_time": total_training_time,
        }

    return None
def validate_model_name_support(model_name, operation: str = "operation"):
    """Raise ValueError unless *model_name* is one of the ``ModelName`` members.

    Args:
        model_name: Value to check (compared with ``==`` against each member).
        operation: Name of the operation, used in the error message.

    Raises:
        ValueError: listing the supported values when *model_name* is unknown.
    """
    # ``__members__`` also includes any enum aliases.
    supported_names = list(ModelName.__members__.values())
    if model_name in supported_names:
        return
    supported_values = [member.value for member in supported_names]
    raise ValueError(
        f"Model name {model_name} not yet implemented for {operation}. "
        f"Currently supported: {supported_values}"
    )
async def process_json_to_dataframe(data, compliance_type: ComplianceType, debug_context: str = "PREDICT") -> Tuple[Any, str, str]:
    """Turn request JSON into a DataFrame, archive it to S3 and stage a temp CSV.

    Args:
        data: Raw JSON payload. It is converted with the FIRCO converter when
            *compliance_type* is ``ComplianceType.FIRCO``, otherwise with the
            non-FIRCO converter.
        compliance_type: Selects the converter.
        debug_context: Tag prefixed to the log lines.

    Returns:
        ``(df, file_path, s3_key)``: the DataFrame, the path of the temp CSV
        copy, and the S3 key of the uploaded debug CSV.

    Raises:
        HTTPException: 400 (via ``validate_dataframe_not_empty``) when the
            conversion yields no data.
    """
    if compliance_type == ComplianceType.FIRCO:
        df = convert_firco_json_to_dataframe(data)
        print(f"[{debug_context}] Firco JSON data converted to DataFrame. Shape: {df.shape}")
    else:
        print("[DEBUG] Non-Firco JSON data received for conversion to DataFrame")
        df = convert_non_firco_json_to_dataframe(data)
        print(f"[{debug_context}] Non-Firco JSON data converted to DataFrame. Shape: {df.shape}")

    validate_dataframe_not_empty(df, "JSON data")

    debug_filename = f"debug_json_data_{datetime.now():%Y%m%d_%H%M%S}.csv"
    # With no target, to_csv() returns the CSV text directly.
    csv_bytes = df.to_csv(index=False).encode('utf-8')

    s3_key = s3_key_for_upload(debug_filename)
    await async_upload_file_to_s3(csv_bytes, s3_key)
    file_path = await async_write_temp_file(csv_bytes, suffix='.csv')

    print(f"[{debug_context}] Debug DataFrame uploaded to S3: {debug_filename}")
    print(f"[{debug_context}] DataFrame columns: {df.columns.tolist()}")
    return df, file_path, s3_key
def validate_csv_file(file: UploadFile):
    """Reject uploads whose filename does not end in ``.csv``.

    The extension check is case-insensitive, so ``DATA.CSV`` is accepted. A
    missing or ``None`` filename is rejected with a 400 instead of crashing
    with AttributeError.

    Raises:
        HTTPException: 400 when the filename is missing or is not a CSV name.
    """
    filename = getattr(file, "filename", None) or ""
    if not filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
def validate_required_param(param_value: Any, param_name: str):
    """Raise a 400 error when *param_value* is falsy.

    Falsy values include ``None``, ``""``, ``0`` and empty containers.
    """
    if param_value:
        return
    raise HTTPException(status_code=400, detail=f"{param_name} is required")
def validate_user_id(user_id: str):
    """Require *user_id* to be a string with at least one non-whitespace character.

    Raises:
        HTTPException: 400 otherwise.
    """
    is_valid = isinstance(user_id, str) and bool(user_id.strip())
    if not is_valid:
        raise HTTPException(status_code=400, detail="Valid user_id is required")
def validate_version(version: str):
    """Require *version* to be a string with at least one non-whitespace character.

    Raises:
        HTTPException: 400 otherwise.
    """
    is_valid = isinstance(version, str) and bool(version.strip())
    if not is_valid:
        raise HTTPException(status_code=400, detail="Valid version is required")
def validate_model_id(model_id: str):
    """Require *model_id* to be a string with at least one non-whitespace character.

    Raises:
        HTTPException: 400 otherwise.
    """
    is_valid = isinstance(model_id, str) and bool(model_id.strip())
    if not is_valid:
        raise HTTPException(status_code=400, detail="Valid model_id is required")
def validate_required_columns(df: pd.DataFrame, required_columns: List[str], context: str = "dataset"):
    """Raise a 400 error naming every column of *required_columns* absent from *df*.

    Missing columns are reported in the order they appear in *required_columns*.
    """
    present_columns = df.columns
    absent = [name for name in required_columns if name not in present_columns]
    if not absent:
        return
    raise HTTPException(
        status_code=400,
        detail=f"Missing required columns in {context}: {absent}"
    )
def format_model_list_response(models: List[Any]) -> Dict[str, List[Dict]]:
    """Serialise a list of models into ``{"models": [...]}``.

    Each item is converted with ``model_dump()`` if it has one, otherwise
    ``dict()``, otherwise ``dict(item)``.
    """
    def as_plain_dict(item: Any) -> Dict:
        if hasattr(item, 'model_dump'):
            return item.model_dump()
        if hasattr(item, 'dict'):
            return item.dict()
        return dict(item)

    return {"models": [as_plain_dict(model) for model in models]}
def format_training_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialise training runs into ``{"training_runs": [...]}``.

    Each run is converted with ``model_dump()`` if it has one, otherwise
    ``dict()``, otherwise ``dict(run)``.
    """
    def as_plain_dict(item: Any) -> Dict:
        if hasattr(item, 'model_dump'):
            return item.model_dump()
        if hasattr(item, 'dict'):
            return item.dict()
        return dict(item)

    return {"training_runs": [as_plain_dict(run) for run in runs]}
async def format_enhanced_training_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """
    Format training runs for an API response, enriched with model metadata and S3 URLs.

    For each run:

    * convert it to a dict (``model_dump()``, then ``dict()``, then ``dict(run)``);
    * if it has a ``model_id``, look the model up and copy its ``model_name`` and
      ``model_version`` into ``model_type`` and ``model_version``. Lookup
      failures are logged and otherwise ignored;
    * if it has a non-empty ``datasets`` dict, add a ``*_url`` entry (built with
      ``get_s3_url``) for each populated ``train``/``validation``/``test``/``s3_key``.
    """
    # (dataset key, URL key) pairs, in the order the URL keys are added.
    url_fields = (
        ('train', 'train_url'),
        ('validation', 'validation_url'),
        ('test', 'test_url'),
        ('s3_key', 'data_file_url'),
    )

    formatted_runs = []
    for run in runs:
        if hasattr(run, 'model_dump'):
            run_dict = run.model_dump()
        elif hasattr(run, 'dict'):
            run_dict = run.dict()
        else:
            run_dict = dict(run)

        model_info = None
        model_id = run_dict.get('model_id')
        if model_id:
            try:
                model_info = await async_model_crud.get_by_model_id(model_id)
            except Exception as e:
                print(f"[WARNING] Failed to get model info for model_id {model_id}: {e}")
        if model_info:
            run_dict['model_type'] = model_info.model_name
            run_dict['model_version'] = model_info.model_version

        datasets = run_dict.get('datasets')
        if datasets:
            for source_key, url_key in url_fields:
                if datasets.get(source_key):
                    datasets[url_key] = get_s3_url(datasets[source_key])

        formatted_runs.append(run_dict)
    return {"training_runs": formatted_runs}
async def format_universal_runs_response(runs: Union[Any, List[Any]], entity_type: str = "run") -> Union[Dict[str, Any], Dict[str, List[Dict]]]:
    """
    Format one or many runs (training/validation/prediction) for an API response.

    Each run is first converted to a dict. When its ``model_id`` resolves to a
    model, it is enriched with ``model_type`` (the display label from
    ``model_names``, or the raw model code if no label is registered) and
    ``model_version``. It is then reshaped by the schema formatter for
    *entity_type*. For any other *entity_type*, the dict is returned as-is
    with ``*_url`` entries added to its ``datasets``.

    Args:
        runs: A single run object, or a list of run objects.
        entity_type: ``"training"``, ``"validation"``, ``"prediction"`` or
            ``"run"``. Also used to build the response key.

    Returns:
        ``{"<entity_type>_runs": [formatted runs]}``. A single run is wrapped
        in a one-element list. In the list case, entries that format to a
        falsy value (e.g. ``None`` runs) are dropped.

    Raises:
        HTTPException: 404 when a single (non-list) *runs* value is falsy.
    """

    async def format_single_run(run):
        """Convert one run to a dict, enrich it with model info and reshape it; None for a falsy run."""
        if not run:
            return None
        if hasattr(run, 'model_dump'):
            run_dict = run.model_dump()
        elif hasattr(run, 'dict'):
            run_dict = run.dict()
        else:
            run_dict = dict(run)

        model_info = None
        if run_dict.get('model_id'):
            try:
                model_info = await async_model_crud.get_by_model_id(run_dict['model_id'])
            except Exception as e:
                print(f"[WARNING] Failed to get model info for model_id {run_dict.get('model_id')}: {e}")
        if model_info:
            # Fall back to the raw code for unregistered models. A bare next()
            # would raise StopIteration, which surfaces as RuntimeError inside a coroutine.
            run_dict['model_type'] = next(
                (entry["label"] for entry in model_names if entry["code"] == model_info.model_name),
                model_info.model_name,
            )
            run_dict['model_version'] = model_info.model_version

        # Transform based on entity type to match the desired schema.
        if entity_type == "training":
            return format_training_run_schema(run_dict)
        if entity_type == "validation":
            return format_validation_run_schema(run_dict)
        if entity_type == "prediction":
            return format_prediction_run_schema(run_dict)

        # Default behaviour for backward compatibility: add S3 URLs for data files.
        if 'datasets' in run_dict and run_dict['datasets']:
            datasets = run_dict['datasets']
            if datasets.get('train'):
                datasets['train_url'] = get_s3_url(datasets['train'])
            if datasets.get('validation'):
                datasets['validation_url'] = get_s3_url(datasets['validation'])
            if datasets.get('test'):
                datasets['test_url'] = get_s3_url(datasets['test'])
            if datasets.get('s3_key'):
                datasets['data_file_url'] = get_s3_url(datasets['s3_key'])
        return run_dict

    # Single run: still returned as a one-element list.
    if not isinstance(runs, list):
        formatted_run = await format_single_run(runs)
        if formatted_run is None:
            raise HTTPException(status_code=404, detail=f"{entity_type.capitalize()} run not found")
        return {f"{entity_type}_runs": [formatted_run]}

    formatted_runs = []
    for run in runs:
        formatted_run = await format_single_run(run)
        if formatted_run:
            formatted_runs.append(formatted_run)
    return {f"{entity_type}_runs": formatted_runs}
def format_training_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Reshape a training-run dict into the public training-run schema.

    ``user_id`` is exposed as ``user_name``. The data-file URL comes from
    ``datasets["s3_key"]`` or, if that yields nothing, ``datasets["train"]``.
    Missing ``training_config`` values fall back to ``test_size=0.2``,
    ``random_state=42``, ``stratify=True`` and ``class_weights=None``.
    """
    datasets = run_dict.get('datasets') or {}
    sizes = run_dict.get('data_size') or {}
    config = run_dict.get('training_config') or {}

    # Some runs keep their sizes under a nested 'bert' entry; unwrap it.
    if 'bert' in sizes:
        sizes = sizes['bert']
        print("[DEBUG] Adjusted data_size for bert model: ", sizes)

    train_file_url = get_s3_url(datasets.get('s3_key')) or get_s3_url(datasets.get('train'))
    config_view = {
        "test_size": config.get('test_size', 0.2),
        "random_state": config.get('random_state', 42),
        "stratify": config.get('stratify', True),
        "class_weights": config.get('class_weights'),
    }

    return {
        "training_id": run_dict.get('training_id'),
        "user_name": run_dict.get('user_id'),
        "model_version": run_dict.get('model_version'),
        "model_type": run_dict.get('model_type'),
        "compliance_type": run_dict.get('compliance_type'),
        "status": run_dict.get('status'),
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "data_size": {"total_records": sizes.get('total_records')},
            "training_config": config_view,
            "train_data": {"data_source": "aws_s3", "data_file": train_file_url},
        },
    }
def format_validation_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Reshape a validation-run dict into the public validation-run schema.

    ``accuracy_file_level`` is an integer percentage taken from
    ``metrics["classification_report"]``:

    * when ``model_name == "prime"``, the report's top-level ``accuracy``;
    * otherwise, the mean of the numeric ``accuracy`` values found inside the
      report's dict-valued entries.

    It is ``None`` when no numeric accuracy is available. An accuracy of 0.0
    yields 0, not ``None``. Values are rounded, not truncated.
    """
    datasets = run_dict.get('datasets') or {}
    data_size = run_dict.get('data_size') or {}
    metrics = run_dict.get('metrics') or {}
    classification_report = metrics.get("classification_report") or {}
    if not isinstance(classification_report, dict):
        # e.g. a text report; there is nothing to extract accuracies from.
        classification_report = {}

    accuracy_file_level = None
    if run_dict.get('model_name') == "prime":
        # NOTE(review): "prime" appears to use a flat report with a top-level accuracy — confirm.
        direct_accuracy = classification_report.get("accuracy")
        # Test for a number, not truthiness: 0.0 is a valid accuracy.
        if isinstance(direct_accuracy, (int, float)):
            # round() rather than int(): int(0.29 * 100) truncates 28.999... to 28.
            accuracy_file_level = round(direct_accuracy * 100)
    else:
        # NOTE(review): assumes one dict per label/target carrying its own "accuracy" — confirm report shape.
        accuracies = [
            label_metrics["accuracy"]
            for label_metrics in classification_report.values()
            if isinstance(label_metrics, dict) and isinstance(label_metrics.get("accuracy"), (int, float))
        ]
        if accuracies:
            avg_accuracy = sum(accuracies) / len(accuracies)
            accuracy_file_level = round(avg_accuracy * 100)

    transformed = {
        "validation_id": run_dict.get('validation_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "model_type": run_dict.get('model_type'),
        "compliance_type": run_dict.get('compliance_type'),
        "status": run_dict.get('status'),
        "accuracy_file_level": accuracy_file_level,
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "data_size": {
                "total_records": data_size.get('total_records'),
            },
            "validate_data_input": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(datasets.get('s3_key')) or get_s3_url(datasets.get('validation'))
            },
            "validate_data_output": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(run_dict.get('results_s3_key'))  # Results JSON written after validation
            }
        }
    }
    return transformed
def format_prediction_run_schema(run_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Reshape a prediction-run dict into the public prediction-run schema.

    ``accuracy_file_level`` is the mean per-prediction score, as an integer
    percentage. Each dict in ``predictions["predictions"]`` contributes its
    ``accuracy``, or its ``probability`` when ``accuracy`` is missing or None.
    Scores <= 1.0 are treated as fractions and scaled by 100. Non-numeric and
    non-finite scores are ignored. The result is ``None`` when no score is
    usable.

    A ``predictions`` value that is not a dict is treated as carrying no
    summary fields, instead of crashing on ``.get``.
    """
    import math

    input_data = run_dict.get('input_data') or {}
    predictions = run_dict.get('predictions') or {}
    # The summary fields read below live on a dict. Anything else (e.g. a bare
    # list) used to make the predictions.get(...) calls raise AttributeError.
    summary = predictions if isinstance(predictions, dict) else {}

    accuracy_file_level = None
    predictions_list = summary.get('predictions', [])
    if isinstance(predictions_list, list):
        scores = []
        for pred in predictions_list:
            if not isinstance(pred, dict):
                continue
            score = pred.get('accuracy')
            if score is None:
                # Only fall back when accuracy is absent: an accuracy of 0 is a real score.
                score = pred.get('probability')
            if not isinstance(score, (int, float)) or not math.isfinite(score):
                continue
            # A probability (0-1) is converted to a percentage.
            scores.append(score * 100 if score <= 1.0 else score)
        if scores:
            # round() rather than int(): int() truncates float error such as 28.999... to 28.
            accuracy_file_level = round(sum(scores) / len(scores))

    transformed = {
        "prediction_id": run_dict.get('prediction_id'),
        "training_id": run_dict.get('training_id'),
        "user_name": run_dict.get('user_id'),  # Map user_id to user_name
        "model_version": run_dict.get('model_version'),
        "compliance_type": run_dict.get('compliance_type'),
        "model_type": run_dict.get('model_type'),
        "status": run_dict.get('status'),
        "accuracy_file_level": accuracy_file_level,
        "created_at": run_dict.get('created_at'),
        "datasets": {
            "total_records": summary.get('total_samples') or summary.get('total_predictions') or input_data.get('record_count'),
            "predict_data_input": {
                "data_source": "aws_s3",
                "data_file": get_s3_url(input_data.get('s3_key'))
            },
            "predict_data_output": {
                "data_source": "aws_s3",
                # Try the predictions summary first, then the run itself.
                "data_file": get_s3_url(summary.get('results_s3_key')) or get_s3_url(run_dict.get('results_s3_key'))
            }
        }
    }
    return transformed
def format_validation_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialise validation runs into ``{"validation_runs": [...]}``.

    Each run is converted with ``model_dump()`` if it has one, otherwise
    ``dict()``, otherwise ``dict(run)``.
    """
    def as_plain_dict(item: Any) -> Dict:
        if hasattr(item, 'model_dump'):
            return item.model_dump()
        if hasattr(item, 'dict'):
            return item.dict()
        return dict(item)

    return {"validation_runs": [as_plain_dict(run) for run in runs]}
def format_prediction_runs_response(runs: List[Any]) -> Dict[str, List[Dict]]:
    """Serialise prediction runs into ``{"prediction_runs": [...]}``.

    Each run is converted with ``model_dump()`` if it has one, otherwise
    ``dict()``, otherwise ``dict(run)``.
    """
    def as_plain_dict(item: Any) -> Dict:
        if hasattr(item, 'model_dump'):
            return item.model_dump()
        if hasattr(item, 'dict'):
            return item.dict()
        return dict(item)

    return {"prediction_runs": [as_plain_dict(run) for run in runs]}
def format_single_model_response(model: Any) -> Dict[str, Any]:
    """Serialise one model for an API response.

    Uses ``model_dump()`` if available, otherwise ``dict()``, otherwise
    ``dict(model)``.

    Raises:
        HTTPException: 404 when *model* is falsy.
    """
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")
    for method_name in ('model_dump', 'dict'):
        if hasattr(model, method_name):
            return getattr(model, method_name)()
    return dict(model)
def generate_model_id() -> str:
    """Return a new random model ID: the canonical string form of a uuid4."""
    return f"{uuid.uuid4()}"
def generate_alert_id() -> str:
    """Return a readable alert ID of the form ``ALT-YYYYMMDD-<8 hex chars>``.

    The date part is the current UTC date.
    """
    today_utc = datetime.now(timezone.utc)
    return "ALT-{:%Y%m%d}-{}".format(today_utc, uuid.uuid4().hex[:8])
def safe_float_conversion(value: Any, default: float = 0.0) -> float:
    """Convert *value* to float, returning *default* when that is not possible.

    *default* is returned for missing values (``pd.isna``: None, NaN, NaT...),
    for values ``float()`` rejects (ValueError/TypeError), and for integers
    too large to represent as a float (OverflowError, e.g. ``10**400``).
    """
    try:
        if pd.isna(value):
            return default
        return float(value)
    except (ValueError, TypeError, OverflowError):
        return default
def safe_int_conversion(value: Any, default: int = 0) -> int:
    """Convert *value* to int, returning *default* when that is not possible.

    *default* is returned for missing values (``pd.isna``: None, NaN, NaT...),
    for values ``int()`` rejects (ValueError/TypeError, e.g. ``"3.5"``), and
    for infinities (OverflowError). Floats are truncated toward zero by ``int()``.
    """
    try:
        if pd.isna(value):
            return default
        return int(value)
    except (ValueError, TypeError, OverflowError):
        return default
def calculate_progress_percentage(current_epoch: int, total_epochs: int) -> float:
    """Return the share of *total_epochs* completed, as a percentage capped at 100.0.

    Returns 0.0 when *total_epochs* is not positive.
    """
    if total_epochs > 0:
        fraction_done = current_epoch / total_epochs
        return min(100.0, fraction_done * 100.0)
    return 0.0
def estimate_completion_time(start_time: datetime, current_epoch: int, total_epochs: int) -> Optional[datetime]:
    """Estimate when training finishes by extrapolating the mean time per epoch.

    Args:
        start_time: When training started. Naive values are interpreted as
            UTC, as before. Timezone-aware values are now accepted too.
        current_epoch: Epochs completed so far.
        total_epochs: Total epochs planned.

    Returns:
        The estimated completion time: naive UTC for a naive *start_time*,
        aware UTC for an aware one. ``None`` when either epoch count is <= 0.
        Once *current_epoch* reaches or exceeds *total_epochs*, the estimate
        is "now" rather than a time in the past.
    """
    if current_epoch <= 0 or total_epochs <= 0:
        return None
    # Read the clock once so elapsed time and the estimate share the same "now".
    now = datetime.now(timezone.utc)
    if start_time.tzinfo is None:
        # Same naive-UTC value datetime.utcnow() gave (utcnow is deprecated since 3.12).
        now = now.replace(tzinfo=None)
    elapsed_time = now - start_time
    time_per_epoch = elapsed_time / current_epoch
    remaining_epochs = max(0, total_epochs - current_epoch)
    return now + time_per_epoch * remaining_epochs
def clean_model_response(model_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a JSON-friendly copy of a model document.

    * The ``_id`` key is renamed to ``id`` and its value converted with
      ``str()`` (presumably a MongoDB ObjectId — confirm against the data layer).
    * ``datetime`` values become ISO-8601 strings.
    * Nested dicts, whether direct values or list items, are cleaned
      recursively, and datetimes inside lists are converted as well.

    Args:
        model_data: The mapping to clean. It is not modified.

    Returns:
        A new dict with the conversions applied.
    """
    cleaned_data = {}
    for key, value in model_data.items():
        if key == "_id":
            cleaned_data["id"] = str(value)
        else:
            cleaned_data[key] = _clean_model_value(value)
    return cleaned_data


def _clean_model_value(value: Any) -> Any:
    """Recursively convert one value for clean_model_response (datetimes, dicts, lists)."""
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, dict):
        return clean_model_response(value)
    if isinstance(value, list):
        return [_clean_model_value(item) for item in value]
    return value
def prepare_prediction_response(predictions: List[Dict[str, Any]],
                                transaction_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Pair each prediction with a transaction ID.

    The i-th prediction gets ``transaction_ids[i]`` when that entry exists,
    otherwise the placeholder ``"transaction_<i>"``.
    """
    known_ids = transaction_ids or []
    return [
        {
            "transaction_id": known_ids[index] if index < len(known_ids) else f"transaction_{index}",
            "predictions": prediction,
        }
        for index, prediction in enumerate(predictions)
    ]
def handle_missing_values(df: pd.DataFrame, fill_strategy: str = "unknown") -> pd.DataFrame:
    """Return a copy of *df* with missing values handled according to *fill_strategy*.

    * ``"drop"``: drop rows containing any missing value;
    * ``"zero"``: fill missing values with 0;
    * ``"unknown"`` (also the fallback for unrecognised strategies): fill with
      the string ``"Unknown"``.

    *df* itself is left untouched.
    """
    working = df.copy()
    if fill_strategy == "drop":
        return working.dropna()
    fill_value = 0 if fill_strategy == "zero" else "Unknown"
    return working.fillna(fill_value)
def validate_config_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a sanitised copy of the training hyper-parameters in *config*.

    Numeric settings are converted, falling back to their default on bad
    input, and clamped to a fixed range:

    ==============  =======  ======  ======
    key             default  min     max
    ==============  =======  ======  ======
    batch_size      8        1       64
    learning_rate   2e-5     1e-6    1e-3
    num_epochs      2        1       20
    max_length      128      32      512
    ==============  =======  ======  ======

    ``model_name`` is stripped, defaulting to ``"roberta-base"`` when it is
    missing, blank or not a string. ``random_state`` is converted to int
    (default 42) without clamping. Any other keys are dropped.
    """
    # (key, default, lower bound, upper bound, converter)
    bounded_settings = (
        ("batch_size", 8, 1, 64, safe_int_conversion),
        ("learning_rate", 2e-5, 1e-6, 1e-3, safe_float_conversion),
        ("num_epochs", 2, 1, 20, safe_int_conversion),
        ("max_length", 128, 32, 512, safe_int_conversion),
    )
    validated_config = {}
    for key, default, lower, upper, convert in bounded_settings:
        converted = convert(config.get(key, default), default)
        validated_config[key] = max(lower, min(upper, converted))

    raw_name = config.get("model_name", "roberta-base")
    has_usable_name = isinstance(raw_name, str) and bool(raw_name.strip())
    validated_config["model_name"] = raw_name.strip() if has_usable_name else "roberta-base"

    validated_config["random_state"] = safe_int_conversion(config.get("random_state", 42), 42)
    return validated_config
def create_error_response(error_message: str, status_code: int = 500) -> HTTPException:
    """Build, but do not raise, a standardised HTTPException.

    Args:
        error_message: Becomes the exception's ``detail``.
        status_code: HTTP status code, 500 by default.

    Returns:
        The exception instance. Raising it is left to the caller.
    """
    return HTTPException(detail=error_message, status_code=status_code)
def generate_training_id() -> str:
    """Return a fresh training ID (same format as ``generate_operation_id``)."""
    training_id = generate_operation_id()
    return training_id
def generate_run_id() -> str:
    """Return a fresh run ID (same format as ``generate_operation_id``)."""
    run_id = generate_operation_id()
    return run_id
async def process_file_upload(file_bytes: bytes, filename: str, context: str) -> Tuple[str, str]:
    """Upload *file_bytes* to S3, then stage the same bytes in a local temp file.

    Args:
        file_bytes: Raw file content.
        filename: Name passed to ``s3_key_for_upload`` to derive the S3 key.
        context: Tag prefixed to the log lines.

    Returns:
        ``(s3_key, file_path)``. Note that the S3 key comes first. The temp
        file always gets a ``.csv`` suffix, whatever *filename* is.
    """
    s3_key = s3_key_for_upload(filename)
    await async_upload_file_to_s3(file_bytes, s3_key)
    print(f"[{context}] File uploaded to S3: {s3_key}")

    temp_path = await async_write_temp_file(file_bytes, suffix='.csv')
    print(f"[{context}] File copied to temp for processing: {temp_path}")
    return s3_key, temp_path
async def validate_model_access(user_id: str, model_name: ModelName, version: str, level: str = None) -> Any:
    """Check that *user_id* owns the training run identified by *version*.

    Returns:
        ``True`` when *version* is ``"latest"`` (case-insensitive); no lookup
        is performed in that case. Otherwise, the matching training run.

    Raises:
        HTTPException: 404 if no run matches; 403 if it belongs to another user.
    """
    from app.services.train import TrainingService
    training_service = TrainingService()
    print(f"[DEBUG] validate_model_access is called with user_id={user_id}, model_name={model_name}, version={version}, level={level}")
    if version.lower() == "latest":
        return True

    # RoBERTa compliance runs are always looked up with level "multi_output".
    lookup_level = "multi_output" if model_name == ModelName.ROBERTA_COMPLIANCE else level
    # NOTE(review): version is passed through unchanged here, whereas
    # get_training_id_from_version converts it with int() — confirm which type
    # get_by_version expects.
    training_run = await training_service.get_by_version(user_id, model_name, version, lookup_level)
    if not training_run:
        raise HTTPException(status_code=404, detail=f"Training run with id '{version}' not found")
    if training_run.user_id != user_id:
        raise HTTPException(status_code=403, detail=f"User '{user_id}' does not have access to version '{version}'")
    print(f"[DEBUG] User '{user_id}' has access to version '{version}'")
    return training_run
async def get_training_id_from_version(user_id: str, version: str, model_name: str, compliance_type: str, level=None) -> str:
    """
    Resolve a model version, or the ``"latest"`` tag, to a training_id.

    Args:
        user_id: User identifier.
        version: ``"latest"`` (case-insensitive) or a version number.
        model_name: Model name.
        compliance_type: Compliance type. ``ComplianceType.NON_FIRCO`` lookups
            always use level ``"multi_output"``.
        level: Level used for all other compliance types.

    Returns:
        str: The resolved training_id.

    Raises:
        HTTPException: 400 if *version* is neither "latest" nor an integer;
            404 if no matching model or training run exists.
    """
    from app.services.train import TrainingService
    print(f"[LATEST_RESOLVE] Checking version: {version} for user_id: {user_id}, model_name: {model_name}, compliance_type: {compliance_type} , level: {level}")
    training_service = TrainingService()
    lookup_level = "multi_output" if compliance_type == ComplianceType.NON_FIRCO else level

    if str(version).lower() == "latest":
        user_and_model = await async_user_and_model_crud.get_by_user_and_model(user_id, model_name, compliance_type, level=lookup_level)
        if not user_and_model:
            print(f"[LATEST_RESOLVE] No trained {compliance_type} model found for user {user_id}")
            raise HTTPException(status_code=404, detail=f"No trained {compliance_type} model found for this user")
        resolved_id = user_and_model.latest_version_training_id
        print(f"[LATEST_RESOLVE] Resolved 'latest' to training_id: {resolved_id} for {compliance_type}")
        return resolved_id

    print(f"[DEBUG] trying to get the version with parameters user_id={user_id}, model_name={model_name}, version={version}, level={level}")
    try:
        version_number = int(version)
    except (TypeError, ValueError) as err:
        # Previously a non-numeric version escaped as a bare ValueError (HTTP 500).
        raise HTTPException(status_code=400, detail=f"Invalid version '{version}': expected 'latest' or an integer") from err
    training_run = await training_service.get_by_version(user_id, model_name, version_number, level=lookup_level)
    if not training_run:
        raise HTTPException(status_code=404, detail=f"Training run with version '{version}' not found")
    print(f"[LATEST_RESOLVE] Using provided version: {version}")
    return training_run.training_id
async def get_validate_result_json_s3_keys(
    validation_id: str,
    compliance_type: ComplianceType,
    model_name: ModelName,
    classification_report: Dict[str, Any],
    training_id: str,
    user_id: str,
    filename: str,
    file_download_url: str,
    detailed_predictions: Any,
):
    """Persist a validation result as JSON, locally and in S3, and return its S3 key.

    The payload is a ``ValidationResponse``. Its ``model_name`` is the display
    label looked up in ``model_names``, falling back to ``str(model_name)``.
    The payload is written to ``MODEL_SAVE_DIR/<validation_id>_results.json``
    and that file's bytes are uploaded under the returned key.

    Failures while building, writing or uploading the payload are logged with
    a traceback and swallowed. The key is returned regardless, so it may
    point at an object that was never uploaded.
    """
    results_s3_key = s3_key_for_upload(f"{validation_id}_results.json")
    print(f"[DEBUG] Generated S3 key: {results_s3_key}")
    try:
        frontend_model_name = ""
        # model_names is a flat list of {"code", "label"} entries.
        for entry in model_names:
            entry_code = entry['code']
            print(f"[DEBUG] Checking code: {entry_code} against model_name: {model_name}")
            if entry_code == model_name or entry_code == str(model_name):
                frontend_model_name = entry['label']
                print(f"[DEBUG] Found frontend model name: {frontend_model_name} for model_name: {model_name}")
                break
        if not frontend_model_name:
            print(f"[WARNING] No frontend model name found for {model_name}, using default")
            frontend_model_name = str(model_name)

        results_url = f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{results_s3_key}" if results_s3_key else None
        temp_response = ValidationResponse(
            message="Validation completed successfully",
            metrics={"classification_report": classification_report},
            training_id=training_id,
            user_id=user_id,
            model_name=frontend_model_name,
            compliance_type=compliance_type,
            file=filename,
            dataset_download_url=file_download_url,
            results_download_url=results_url,
            detailed_predictions=detailed_predictions,
        )

        result_json_path = MODEL_SAVE_DIR / f"{validation_id}_results.json"
        with open(result_json_path, "w") as out_file:
            json.dump(temp_response.model_dump(), out_file, indent=2)
        with open(result_json_path, "rb") as in_file:
            payload = in_file.read()
        await async_upload_file_to_s3(payload, results_s3_key)
    except Exception as e:
        print(f"[ERROR] Failed to save validation results to S3: {str(e)}")
        print(f"[ERROR] Error type: {type(e).__name__}")
        print(f"[ERROR] Full traceback:")
        traceback.print_exc()
    return results_s3_key
async def get_prediction_result_json_s3_keys(
    prediction_id: str,
    prediction_results: List[Dict[str, Any]],
    file_download_url: str,
    filename: str,
    training_id: str,
    compliance_type: ComplianceType,
    model_name: ModelName,
    source_type: str,
):
    """Persist batch prediction results as JSON, locally and in S3, and return the S3 key.

    The payload is a ``BatchPredictionResponse``. Its ``model_name`` is the
    display label looked up in ``model_names``, falling back to
    ``str(model_name)`` instead of an empty string. The payload is written to
    ``MODEL_SAVE_DIR/<prediction_id>_results.json`` and that file is uploaded
    under the returned key.

    Failures after the key has been generated are logged (with traceback) and
    swallowed, so the key is returned even if the upload did not happen.
    """
    # Generate the key before the try block. If this failed inside the try,
    # the final return raised UnboundLocalError and hid the real error.
    results_s3_key = s3_key_for_upload(f"{prediction_id}_results.json")
    try:
        result_json_path = MODEL_SAVE_DIR / f"{prediction_id}_results.json"
        frontend_model_name = ""
        print(f"[DEBUG] model_names structure: {model_names}")
        # model_names is a flat list of {"code", "label"} entries.
        for entry in model_names:
            entry_code = entry['code']
            print(f"[DEBUG] Checking code: {entry_code} against model_name: {model_name}")
            if entry_code == model_name or entry_code == str(model_name):
                frontend_model_name = entry['label']
                print(f"[DEBUG] Found frontend model name: {frontend_model_name} for model_name: {model_name}")
                break
        if not frontend_model_name:
            # Never store an empty model name in the results payload.
            print(f"[WARNING] No frontend model name found for {model_name}, using default")
            frontend_model_name = str(model_name)

        temp_response = BatchPredictionResponse(
            message="Prediction completed successfully",
            predictions=prediction_results,
            training_id=training_id,
            model_name=frontend_model_name,
            compliance_type=compliance_type,
            source_type=source_type,
            file=filename,
            dataset_download_url=file_download_url,
            results_download_url=f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{results_s3_key}" if results_s3_key else None
        )
        with open(result_json_path, "w") as f:
            json.dump(temp_response.model_dump(), f)
        with open(result_json_path, "rb") as f:
            await async_upload_file_to_s3(f.read(), results_s3_key)
        print(f"[PREDICT] Results JSON uploaded to S3: {results_s3_key}")
    except Exception as e:
        print(f"Failed to save prediction results to S3: {str(e)}")
        traceback.print_exc()
    return results_s3_key