File size: 5,501 Bytes
f725085
 
 
 
 
 
 
1bc4593
f725085
 
 
 
 
 
1bc4593
 
f725085
 
 
 
1bc4593
 
f725085
 
 
 
 
 
 
 
 
 
 
 
2a044a1
9ee11ef
f725085
1bc4593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f725085
 
 
 
9ee11ef
f725085
 
 
9ee11ef
f725085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a044a1
f725085
 
 
 
 
 
 
 
 
2a044a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f725085
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Request,Query

from pydantic import BaseModel
from typing import Union
import base64

from app.model import load_model, predict_with_model,get_heatmap,cache
import os
import threading
import time
from app.utils import heartbeat,register_forever
from app.log import logger
from app.config import MODEL_NAME, ENV,MODEL_TYPE
from typing import Optional
import psutil
# Record, once at import time, which environment configuration is active.
logger.info(f"ENV :{ENV}")

# Application instance on which every route and middleware below is registered.
app = FastAPI()

# Wall-clock time (epoch seconds) at import; /admin/stats derives uptime from it.
start_time: float = time.time()
# Incremented by the `count_requests` middleware for paths it does not exclude.
request_count: int = 0

def load_models_once():
    """Invoke ``load_model()`` once and discard its result.

    Called from the startup hook. Presumably ``load_model`` memoizes what it
    loads, so the per-request calls made later are cheap (not visible here).
    """
    load_model()

@app.on_event("startup")
def startup():
    """Load the models, then run ``register_forever`` and ``heartbeat`` in daemon threads.

    Daemon threads do not keep the process alive on shutdown. The workers
    are started in a fixed order: registration first, then heartbeat.
    """
    load_models_once()
    for worker in (register_forever, heartbeat):
        threading.Thread(target=worker, daemon=True).start()


# Request-body schema for a base64-encoded image.
# NOTE(review): /predict and /heatmap also declare a File(...) parameter, so
# FastAPI parses their bodies as form data and this model is probably never
# filled from an application/json request — confirm.
class ImagePayload(BaseModel):
    # Base64-encoded image bytes (decoded with base64.b64decode by the handlers).
    image: str
    # Class index passed to get_heatmap; required by the schema even though
    # /predict does not use it.
    predicted_class_index: int




# Request paths that the counting middleware deliberately ignores.
_UNCOUNTED_PATHS = frozenset((
    "/admin/clear-predictions",
    "/admin/clear-heatmaps",
    "/admin/stats",
    "/admin/logs",
    "/admin/reset-state",
))


@app.middleware("http")
async def count_requests(request, call_next):
    """Bump the module-level ``request_count`` unless the path is in ``_UNCOUNTED_PATHS``.

    The request is always forwarded unchanged to the next handler.
    """
    global request_count
    should_count = request.url.path not in _UNCOUNTED_PATHS
    if should_count:
        request_count += 1
    return await call_next(request)


def _purge_cache_suffix(suffix):
    """Delete every cache entry whose string key ends with *suffix*.

    Keys are snapshotted before deleting, so the cache is never mutated while
    being iterated. A key that vanishes between the snapshot and its deletion
    (expired, evicted, or removed by a concurrent request) is skipped instead
    of aborting the whole purge with ``KeyError``. Non-string keys are ignored
    rather than crashing ``endswith``.

    NOTE(review): ``iterkeys()`` suggests ``cache`` is a ``diskcache.Cache`` — confirm.

    Returns:
        The number of entries actually removed.
    """
    removed = 0
    for key in list(cache.iterkeys()):
        if not (isinstance(key, str) and key.endswith(suffix)):
            continue
        try:
            del cache[key]
        except KeyError:
            # Already gone: nothing to remove, and not worth failing the request.
            continue
        removed += 1
    return removed


@app.post("/admin/clear-predictions")
def clear_predictions():
    """Remove all cached predictions (keys ending in ``_pred``) and report how many."""
    count = _purge_cache_suffix("_pred")
    return {"message": f"✅ {count} prédictions supprimées"}

@app.post("/admin/clear-heatmaps")
def clear_heatmaps():
    """Remove all cached heatmaps (keys ending in ``_heatmap``) and report how many."""
    count = _purge_cache_suffix("_heatmap")
    return {"message": f"🔥 {count} heatmaps supprimées"}



@app.get("/admin/stats")
def get_stats():
    """Report process uptime, request counter, cache size and resident memory.

    ``uptime_human`` is ``HH:MM:SS`` where the hour field keeps counting past
    24 (e.g. ``"27:03:10"``). Formatting through ``time.gmtime``/``strftime``
    wrapped back to ``00`` every day. Uptime is clamped at 0 so a wall-clock
    step backwards cannot produce a negative duration.
    """
    uptime_sec = max(0, int(time.time() - start_time))
    hours, remainder = divmod(uptime_sec, 3600)
    minutes, seconds = divmod(remainder, 60)
    mem_mb = psutil.Process().memory_info().rss / 1024**2
    return {
        "uptime_seconds": uptime_sec,
        "uptime_human": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
        "request_count": request_count,
        "cache_items": len(cache),
        "memory_usage_mb": f"{mem_mb:.2f}",
    }


@app.get("/admin/logs")
def get_logs(lines: Optional[int] = Query(50, ge=1, le=500)):
    """Return the last ``lines`` lines of the application log file.

    Only the requested tail is held in memory (a bounded ``deque``), instead
    of reading the whole file into a list first.

    The file is decoded as UTF-8 with replacement characters, because log
    messages contain emoji that fail to decode on non-UTF-8 locales.
    NOTE(review): assumes the log handler writes UTF-8 — confirm.

    Returns:
        ``{"logs": [...]}`` with raw lines (newlines included), or a
        one-element placeholder list when the file does not exist.
    """
    from collections import deque

    log_path = "app.log"  # or the actual path of your log file
    try:
        log_file = open(log_path, "r", encoding="utf-8", errors="replace")
    except FileNotFoundError:
        return {"logs": ["Aucun fichier log disponible."]}

    with log_file:
        return {"logs": list(deque(log_file, maxlen=lines))}



@app.post("/predict")
async def predict(request: Request,
    file: UploadFile = File(None),
    payload: Union[ImagePayload, None] = None,
    ):
    """Run the first loaded model on an image and return its prediction.

    The image is accepted either as a multipart ``file`` field or as a JSON
    body ``{"image": "<base64>"}``.

    Because this route declares a ``File`` parameter, FastAPI parses the whole
    body as form data. A plain ``application/json`` request therefore arrives
    with both ``file`` and ``payload`` set to ``None``, so the JSON body is read
    directly from ``request`` as a fallback.

    Raises:
        HTTPException(400): no image supplied, malformed JSON, or invalid base64.
        HTTPException(500): no model loaded, or inference raised.
    """
    logger.info("🔁 Requête reçue")

    try:
        if file is not None:
            # Case 1: multipart upload.
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename} ({len(image_bytes)} octets)")
        else:
            # Case 2: base64 image, from `payload` or read from a raw JSON body.
            encoded = payload.image if payload is not None else None
            if encoded is None and "json" in request.headers.get("content-type", "").lower():
                try:
                    body = await request.json()
                except ValueError:
                    raise HTTPException(status_code=400, detail="Corps JSON invalide.") from None
                if isinstance(body, dict):
                    encoded = body.get("image")

            if encoded is None:
                logger.info("⚠️ Aucune image reçue")
                raise HTTPException(status_code=400, detail="Format de requête non supporté.")

            try:
                image_bytes = base64.b64decode(encoded)
            except (TypeError, ValueError):
                # binascii.Error (bad padding/characters) is a ValueError subclass.
                raise HTTPException(status_code=400, detail="Image base64 invalide.") from None
            logger.debug(f"✅ Image décodée depuis base64 : {len(image_bytes)} octets")

        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")
        return predict_with_model(models[0], image_bytes)

    except HTTPException:
        # Deliberate HTTP errors keep their own status code; without this they
        # would be rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.error("❌ Une erreur s'est produite", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))





@app.post("/heatmap")
async def predict_heatmap(
    request: Request,
    payload: Union[ImagePayload, None] = None,
    file: UploadFile = File(None),
    predicted_class_index: int = Query(None)
):
    """Compute a heatmap for an image with the first loaded model via ``get_heatmap``.

    Accepted inputs:
      * multipart ``file`` plus ``predicted_class_index`` in the query string;
      * a JSON body ``{"image": "<base64>", "predicted_class_index": <int>}``.
        The query-string value is used when the body omits the index.

    Because this route declares a ``File`` parameter, FastAPI parses the body
    as form data. Plain ``application/json`` requests therefore leave
    ``payload`` as ``None``, and the JSON body is read directly from
    ``request`` as a fallback.

    Returns:
        ``{"heatmap": <whatever get_heatmap returns>}``.

    Raises:
        HTTPException(400): missing image or class index, malformed JSON,
            invalid base64, or a non-integer class index.
        HTTPException(500): no model loaded, or the heatmap computation raised.
    """
    logger.info("🔁 Requête reçue pour heatmap")

    try:
        if file is not None:
            # Multipart upload: the class index must come from the query string.
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename} ({len(image_bytes)} octets)")
            if predicted_class_index is None:
                raise HTTPException(status_code=400, detail="predicted_class_index requis en query avec fichier multipart")

        else:
            encoded = None
            if payload is not None:
                encoded = payload.image
                predicted_class_index = payload.predicted_class_index
            elif "json" in request.headers.get("content-type", "").lower():
                try:
                    body = await request.json()
                except ValueError:
                    raise HTTPException(status_code=400, detail="Corps JSON invalide.") from None
                if isinstance(body, dict):
                    encoded = body.get("image")
                    # A body-supplied index overrides the query string, as with `payload`.
                    if body.get("predicted_class_index") is not None:
                        predicted_class_index = body["predicted_class_index"]

            if encoded is None:
                raise HTTPException(status_code=400, detail="Aucune image reçue")
            if predicted_class_index is None:
                raise HTTPException(status_code=400, detail="predicted_class_index requis")

            try:
                predicted_class_index = int(predicted_class_index)
                image_bytes = base64.b64decode(encoded)
            except (TypeError, ValueError):
                # binascii.Error (bad base64) is a ValueError subclass.
                raise HTTPException(status_code=400, detail="Image base64 ou predicted_class_index invalide.") from None
            logger.debug(f"✅ Image reçue en JSON base64 : {len(image_bytes)} octets, class={predicted_class_index}")

        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")

        heatmap = get_heatmap(models[0], image_bytes, predicted_class_index)
        return {"heatmap": heatmap}

    except HTTPException:
        # Keep deliberate 4xx/5xx responses intact instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error("❌ Erreur heatmap", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))



@app.get("/health")
def health_check():
    """Health endpoint: constant ``"ok"`` status, configured model identity, current epoch time."""
    report = dict(
        status="ok",
        model_name=MODEL_NAME,
        model_type=MODEL_TYPE,
    )
    report["timestamp"] = time.time()
    return report