Spaces:
Sleeping
Sleeping
| # src/api.py - Enhanced API with better error handling for patient data | |
| from fastapi import FastAPI, HTTPException, File, UploadFile | |
| from pydantic import BaseModel, Field | |
| import torch | |
| import numpy as np | |
| import joblib | |
| from src.model import TabularVAE | |
| from typing import List, Optional, Dict, Any | |
| import os | |
| import shutil | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| import json | |
app = FastAPI(title="Healthcare VAE API", version="1.0.0")

# Latent dimensionality handed to TabularVAE and used when sampling z.
LATENT_DIM = 8

# Bind every module-level name the request handlers read *before* touching
# the filesystem, so that a failed load leaves well-defined placeholders
# instead of NameErrors at request time.
feature_names = [
    "age", "gender", "diagnosis", "blood_type", "length_of_stay",
    "age_group", "admission_season", "admission_day", "admission_month",
    "admission_year",
]
INPUT_DIM = len(feature_names)
model = None
scaler = None
encoders = None

# Load model and scaler (best effort: failures are reported, not raised).
try:
    # Load feature names and determine input dimension
    if os.path.exists("models/feature_names.pkl"):
        _loaded_names = joblib.load("models/feature_names.pkl")
        INPUT_DIM = len(_loaded_names)
        feature_names = _loaded_names
        print(f"Loaded {INPUT_DIM} features: {feature_names}")
    else:
        # Fallback to default features
        print(f"Using default {INPUT_DIM} features")

    # Build into a temporary so `model` stays None unless the weights load;
    # an architecture with random weights must never be served.
    _candidate = TabularVAE(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    _candidate.load_state_dict(torch.load("models/vae_model.pth", map_location='cpu'))
    _candidate.eval()
    model = _candidate

    scaler = joblib.load("models/scaler.pkl")

    # Load encoders if available
    if os.path.exists("models/encoders.pkl"):
        encoders = joblib.load("models/encoders.pkl")
    print("Model and scaler loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please run training first!")
# Request body for synthetic-sample generation.
class GenerateRequest(BaseModel):
    # Number of rows to decode; validated to lie in 1..1000.
    n_samples: int = Field(
        ...,
        ge=1,
        le=1000,
        description="Number of samples to generate",
    )
    # When given, seeds the torch and numpy RNGs before sampling.
    random_seed: Optional[int] = Field(
        default=None,
        description="Random seed for reproducibility",
    )
    # Multiplier applied to the sampled latent noise; validated to 0.1..2.0.
    temperature: float = Field(
        default=1.0,
        ge=0.1,
        le=2.0,
        description="Sampling temperature",
    )
# Request body describing one patient record to be encoded.
class PatientData(BaseModel):
    # --- required fields ---
    age: float = Field(
        ...,
        ge=0,
        le=120,
        description="Patient age",
    )
    # Free-form strings: no validator restricts them to a fixed vocabulary.
    gender: str = Field(
        ...,
        description="Patient gender (Male/Female)",
    )
    diagnosis: str = Field(
        ...,
        description="Patient diagnosis",
    )
    blood_type: str = Field(
        ...,
        description="Blood type",
    )

    # --- optional fields (None when the client omits them) ---
    length_of_stay: Optional[float] = Field(
        default=None,
        description="Length of stay in days",
    )
    age_group: Optional[int] = Field(
        default=None,
        ge=0,
        le=4,
        description="Age group (0-4)",
    )
    admission_season: Optional[int] = Field(
        default=None,
        ge=0,
        le=3,
        description="Admission season (0-3)",
    )
    admission_day: Optional[int] = Field(
        default=None,
        ge=0,
        le=6,
        description="Admission day of week (0-6)",
    )
    admission_month: Optional[int] = Field(
        default=None,
        ge=0,
        le=11,
        description="Admission month (0-11)",
    )
    admission_year: Optional[int] = Field(
        default=None,
        description="Admission year (normalized)",
    )
# Shape of a generation result: a table of floats plus descriptive metadata.
class GeneratedResponse(BaseModel):
    # One inner list per generated sample.
    data: List[List[float]]
    # Free-form description of how the samples were produced.
    metadata: dict
def convert_numpy_to_python(obj):
    """Recursively replace numpy values with JSON-serializable Python ones.

    Conversions:
        * ``np.bool_``    -> ``bool``
        * ``np.integer``  -> ``int``
        * ``np.floating`` -> ``float``
        * ``np.ndarray``  -> nested ``list`` (via ``tolist``)
        * ``list`` / ``tuple`` / ``dict`` -> same container kind, with every
          element (or dict value) converted recursively

    Anything else is returned unchanged. Dict keys are not converted.
    """
    # np.bool_ is neither np.integer nor np.floating, so it needs its own case.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_numpy_to_python(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_to_python(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples can hide numpy scalars just like lists; keep the tuple type.
        return tuple(convert_numpy_to_python(item) for item in obj)
    return obj
def read_root():
    """Return a liveness message together with the feature names in use."""
    payload = {
        "message": "Healthcare VAE API is running!",
        "features": feature_names,
    }
    return payload
def get_features():
    """Describe the model's inputs: feature names plus input/latent sizes."""
    return dict(
        feature_names=feature_names,
        input_dim=INPUT_DIM,
        latent_dim=LATENT_DIM,
    )
def generate_synthetic_data(request: GenerateRequest):
    """Decode random latent vectors into synthetic rows on the original scale.

    Samples ``request.n_samples`` standard-normal latent vectors, multiplies
    them by ``request.temperature``, decodes them with the module-level
    ``model`` and maps the result back through ``scaler.inverse_transform``.

    Returns:
        A dict with ``"data"`` (list of rows) and ``"metadata"``.

    Raises:
        HTTPException: status 500 wrapping any exception raised on the way.
    """
    try:
        seed = request.random_seed
        if seed is not None:
            # Seeds the process-wide torch and numpy generators.
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Generate samples
        noise = torch.randn(request.n_samples, LATENT_DIM)
        latent = noise * request.temperature
        with torch.no_grad():
            decoded = model.decode(latent).numpy()

        # Inverse transform to original scale
        rows = scaler.inverse_transform(decoded).tolist()

        return {
            "data": rows,
            "metadata": {
                "n_samples": request.n_samples,
                "latent_dim": LATENT_DIM,
                "temperature": request.temperature,
                "features": feature_names,
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
def _stable_bucket(text, n_buckets):
    """Map *text* to an integer in ``[0, n_buckets)`` reproducibly.

    The builtin ``hash`` is salted per process for ``str`` values, so it
    cannot back an encoding that has to be the same after a restart. CRC-32
    of the UTF-8 bytes is deterministic.
    """
    import zlib  # local import keeps this block self-contained

    return zlib.crc32(text.encode("utf-8")) % n_buckets


def _encode_category(name, value, fallback):
    """Encode one categorical field with its fitted encoder, else *fallback*."""
    if encoders and name in encoders:
        return encoders[name].transform([value])[0]
    return fallback(value)


def _value_or(value, default):
    """Return *value*, or *default* when the client left the field unset."""
    return default if value is None else value


def encode_patient(patient: PatientData):
    """Encode patient data to latent space.

    The feature vector is always assembled in the fixed order age, gender,
    diagnosis, blood_type, length_of_stay, age_group, admission_season,
    admission_day, admission_month, admission_year, independent of the
    contents of ``feature_names``. It is then zero-padded or truncated to
    ``INPUT_DIM``, scaled with ``scaler`` and passed to ``model.encode``.

    Categorical fields use the matching entry of ``encoders`` when present.
    Otherwise gender maps to 0 for "male" (case-insensitive) and 1 for
    anything else, and diagnosis / blood type map to a stable bucket index
    (10 and 8 buckets respectively).

    Returns:
        A dict with ``latent_mean``, ``latent_logvar``, ``features_used``
        and ``feature_values``.

    Raises:
        HTTPException: status 500 wrapping any exception raised on the way.
    """
    try:
        # Convert patient data to feature vector
        feature_vector = [
            patient.age,
            _encode_category(
                "gender", patient.gender,
                lambda g: 0 if g.lower() == 'male' else 1,
            ),
            _encode_category(
                "diagnosis", patient.diagnosis,
                lambda d: _stable_bucket(d, 10),
            ),
            _encode_category(
                "blood_type", patient.blood_type,
                lambda b: _stable_bucket(b, 8),
            ),
            _value_or(patient.length_of_stay, 7.0),
            _value_or(patient.age_group, 2),
            _value_or(patient.admission_season, 0),
            _value_or(patient.admission_day, 0),
            _value_or(patient.admission_month, 0),
            _value_or(patient.admission_year, 4),
        ]

        # Ensure we have the right number of features: pad with zeros when
        # short (a negative repeat count yields an empty list), then truncate.
        feature_vector.extend([0.0] * (INPUT_DIM - len(feature_vector)))
        feature_vector = feature_vector[:INPUT_DIM]

        # Convert to array and scale
        data = np.array([feature_vector])
        scaled_data = scaler.transform(data)
        tensor_data = torch.tensor(scaled_data, dtype=torch.float32)
        with torch.no_grad():
            mu, logvar = model.encode(tensor_data)

        # Convert numpy types to Python native types for JSON serialization
        return {
            "latent_mean": convert_numpy_to_python(mu.numpy().tolist()),
            "latent_logvar": convert_numpy_to_python(logvar.numpy().tolist()),
            "features_used": feature_names,
            "feature_values": convert_numpy_to_python(feature_vector),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}") from e
def health_check():
    """Health check endpoint.

    Reports whether the model and scaler are actually available instead of
    unconditionally claiming success.

    Returns:
        A dict with ``status`` ("healthy" when both model and scaler are
        loaded, otherwise "degraded"), ``model_loaded`` (bool), ``input_dim``
        and ``latent_dim`` (``None`` when not known).
    """
    # The startup loader may have failed before binding some of these names,
    # so look them up defensively rather than risking a NameError here.
    state = globals()
    loaded = state.get("model") is not None and state.get("scaler") is not None
    return {
        "status": "healthy" if loaded else "degraded",
        "model_loaded": loaded,
        "input_dim": state.get("INPUT_DIM"),
        "latent_dim": state.get("LATENT_DIM"),
    }
async def upload_data(file: UploadFile = File(...)):
    """Store an uploaded CSV file for continual training.

    The upload is streamed to the fixed path ``data/new_data.csv``, replacing
    whatever was there; the content itself is not inspected.

    Returns:
        A dict with ``status`` and the client-supplied ``filename``.
    """
    target_dir = "data"
    os.makedirs(target_dir, exist_ok=True)

    destination = "data/new_data.csv"
    with open(destination, "wb") as sink:
        shutil.copyfileobj(file.file, sink)

    result = {"status": "success", "filename": file.filename}
    return result
def get_training_progress():
    """Get the latest training progress metrics for the web interface.

    Reads ``data/training_progress.json`` and returns its parsed content.

    Returns:
        ``JSONResponse`` with the file's content, or a 404 ``JSONResponse``
        with ``status == "no_progress"`` when the file does not exist.
    """
    progress_file = "data/training_progress.json"
    # Open directly instead of checking existence first: the file can vanish
    # between the check and the open. JSON is read as UTF-8 rather than in the
    # platform's locale encoding.
    try:
        with open(progress_file, "r", encoding="utf-8") as f:
            progress = json.load(f)
    except FileNotFoundError:
        return JSONResponse(
            content={"status": "no_progress", "message": "No training progress found."},
            status_code=404,
        )
    return JSONResponse(content=progress)
def dashboard():
    """Serve a self-contained HTML page that polls training progress.

    The page fetches ``/training_progress`` immediately and then every three
    seconds, filling in epoch, losses and the last-updated time. When a poll
    fails, a message is written to a dedicated status line; the metric fields
    stay in the document so a later successful poll can update them.

    Returns:
        ``HTMLResponse`` containing the page.
    """
    html = '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Training Progress Dashboard</title>
<style>
body { font-family: Arial, sans-serif; margin: 2em; background: #f9f9f9; }
h1 { color: #2c3e50; }
#progress { background: #fff; padding: 1em; border-radius: 8px; box-shadow: 0 2px 8px #eee; max-width: 400px; }
.label { color: #888; }
</style>
</head>
<body>
<h1>Training Progress</h1>
<div id="progress">
<div id="status"></div>
<div><span class="label">Epoch:</span> <span id="epoch">-</span></div>
<div><span class="label">Train Loss:</span> <span id="train_loss">-</span></div>
<div><span class="label">Val Loss:</span> <span id="val_loss">-</span></div>
<div><span class="label">Best Val Loss:</span> <span id="best_val_loss">-</span></div>
<div><span class="label">Last Updated:</span> <span id="timestamp">-</span></div>
</div>
<script>
async function fetchProgress() {
  const statusEl = document.getElementById('status');
  try {
    const res = await fetch('/training_progress');
    if (!res.ok) throw new Error('No progress yet');
    const data = await res.json();
    document.getElementById('epoch').textContent = data.epoch;
    document.getElementById('train_loss').textContent = data.train_loss?.toFixed(4);
    document.getElementById('val_loss').textContent = data.val_loss?.toFixed(4);
    document.getElementById('best_val_loss').textContent = data.best_val_loss?.toFixed(4);
    const date = new Date(data.timestamp * 1000);
    document.getElementById('timestamp').textContent = date.toLocaleString();
    statusEl.textContent = '';
  } catch (e) {
    // Only touch the status line: replacing the whole card would remove the
    // metric spans and make every later poll fail as well.
    statusEl.innerHTML = '<b>No training progress yet.</b>';
  }
}
fetchProgress();
setInterval(fetchProgress, 3000);
</script>
</body>
</html>
'''
    return HTMLResponse(content=html)
| # Run with: uvicorn src.api:app --reload --host 0.0.0.0 --port 8000 |