File size: 3,574 Bytes
8c9c9e7
 
3d91615
 
8c9c9e7
 
 
350b2c3
3d91615
 
8c9c9e7
 
 
 
 
3d91615
 
 
 
 
 
 
 
 
 
 
 
8c9c9e7
 
6e2b7e0
 
 
 
 
350b2c3
 
 
 
 
 
 
 
 
 
 
 
6e2b7e0
350b2c3
 
 
 
 
 
 
 
 
 
 
 
 
8c9c9e7
6e2b7e0
3d91615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c9c9e7
3d91615
 
 
 
8c9c9e7
3d91615
 
 
 
 
 
 
 
 
 
 
8c9c9e7
 
 
3d91615
 
8c9c9e7
3d91615
 
8c9c9e7
 
 
3d91615
 
 
 
 
8c9c9e7
3d91615
8c9c9e7
 
3d91615
8c9c9e7
3d91615
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
import os
from typing import Optional

import joblib
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

# Configure module-wide logging; INFO level surfaces the model-load progress
# messages emitted below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; metadata shows up in the auto-generated docs.
app = FastAPI(title="Health Monitoring System",
             description="A FastAPI application for health monitoring and prediction",
             version="1.0.0")

# Add CORS middleware so browser clients on any origin can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive — tighten the origin list before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load models
def load_models():
    """Download and load the trained models from the Hugging Face Hub.

    Sets the module-level globals ``heart_model`` and ``autoencoder``.
    Loading is best-effort: each model is attempted independently, a
    failure is logged, and the corresponding global stays ``None`` so
    the ``/health`` endpoint can report which models are available.
    """
    global heart_model, autoencoder
    heart_model = None
    autoencoder = None

    try:
        # Download and load heart model (scikit-learn style, joblib format)
        logger.info("Downloading heart model from Hugging Face Hub...")
        heart_model_path = hf_hub_download(
            repo_id="leo861/app",
            filename="heart/models/heart_model.joblib",
            cache_dir="models"
        )
        heart_model = joblib.load(heart_model_path)
        logger.info("Heart model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load heart model: {str(e)}")

    try:
        # Download and load autoencoder (full pickled torch module)
        logger.info("Downloading autoencoder from Hugging Face Hub...")
        autoencoder_path = hf_hub_download(
            repo_id="leo861/app",
            filename="models/best_model.pth",
            cache_dir="models"
        )
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from a CUDA device (the original call crashed
        # there with a device deserialization error).
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted repositories.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        autoencoder = torch.load(autoencoder_path, map_location=device)
        autoencoder.eval()  # inference mode: disables dropout/batch-norm updates
        logger.info("Autoencoder model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load autoencoder: {str(e)}")

# Load models on startup
# Load models on startup
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — consider migrating when upgrading FastAPI.
@app.on_event("startup")
async def startup_event():
    """Load the trained models once when the API process starts.

    The try/except is defensive: load_models() already swallows per-model
    failures, so this only guards against unexpected errors in the loader
    itself, keeping the server able to start regardless.
    """
    logger.info("Loading trained models...")
    try:
        load_models()
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")

# Define request models
class PredictionRequest(BaseModel):
    """Request body for POST /predict."""
    # Free-form feature payload; no schema is enforced at this layer.
    data: dict

# Define response models
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # Overall service status string (the endpoint always reports "healthy").
    status: str
    # Maps model name -> bool indicating whether it loaded successfully.
    models: dict

class PredictionResponse(BaseModel):
    """Response body for POST /predict."""
    status: str
    prediction: str
    # Fix: the original annotation was ``message: str = None`` — a default
    # that violates its own type. Optional[str] declares the nullable intent
    # correctly (and avoids validation issues under strict pydantic).
    message: Optional[str] = None

@app.get("/")
async def root():
    """Return a static welcome message for the API root."""
    welcome = {"message": "Welcome to the Health Monitoring System API"}
    return welcome

@app.get("/health", response_model=HealthResponse)
async def health():
    """Report service liveness and which models loaded successfully.

    Reads the module globals set by load_models(); a model that failed to
    load (or was never loaded) shows up as False.
    """
    model_availability = {
        "heart_model": heart_model is not None,
        "autoencoder": autoencoder is not None,
    }
    return {"status": "healthy", "models": model_availability}

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Run a prediction on the posted payload.

    Returns a PredictionResponse-shaped dict. Raises HTTP 400 when the
    payload is empty and HTTP 500 on any unexpected failure.
    """
    try:
        if not request.data:
            raise HTTPException(status_code=400, detail="No data provided")

        # Add your prediction logic here
        logger.info("Processing prediction request")
        result = {
            "status": "success",
            "prediction": "normal",
            "message": "Prediction completed successfully"
        }
        logger.info(f"Prediction completed: {result}")
        return result
    except HTTPException:
        # Bug fix: without this pass-through, the 400 raised above was
        # caught by the generic handler below and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Serve on all interfaces; port 7860 is presumably chosen for Hugging
    # Face Spaces (its default app port) — confirm for other deployments.
    uvicorn.run(app, host="0.0.0.0", port=7860)