Update app.py
Browse files
app.py
CHANGED
|
@@ -180,37 +180,29 @@ class BatchPredictionResponse(BaseModel):
|
|
| 180 |
async def root():
|
| 181 |
return {"message": "BERT Compliance Predictor API"}
|
| 182 |
|
| 183 |
-
@app.get("/health")
|
| 184 |
async def health_check():
|
| 185 |
return {"status": "healthy"}
|
| 186 |
|
| 187 |
-
@app.get("/training-status")
|
| 188 |
async def get_training_status():
|
| 189 |
return training_status
|
| 190 |
|
| 191 |
-
@app.post("/
|
| 192 |
-
async def upload_file(file: UploadFile = File(...)):
|
| 193 |
-
"""Upload a CSV file for training or validation"""
|
| 194 |
-
if not file.filename.endswith('.csv'):
|
| 195 |
-
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
| 196 |
-
|
| 197 |
-
file_path = UPLOAD_DIR / file.filename
|
| 198 |
-
with file_path.open("wb") as buffer:
|
| 199 |
-
shutil.copyfileobj(file.file, buffer)
|
| 200 |
-
|
| 201 |
-
return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
|
| 202 |
-
|
| 203 |
-
@app.post("/bert/train", response_model=TrainingResponse)
|
| 204 |
async def start_training(
|
| 205 |
config: TrainingConfig,
|
| 206 |
background_tasks: BackgroundTasks,
|
| 207 |
-
|
| 208 |
):
|
| 209 |
if training_status["is_training"]:
|
| 210 |
raise HTTPException(status_code=400, detail="Training is already in progress")
|
| 211 |
|
| 212 |
-
if not
|
| 213 |
-
raise HTTPException(status_code=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 216 |
|
|
@@ -222,9 +214,9 @@ async def start_training(
|
|
| 222 |
"status": "starting"
|
| 223 |
})
|
| 224 |
|
| 225 |
-
background_tasks.add_task(train_model_task, config, file_path, training_id)
|
| 226 |
|
| 227 |
-
download_url = f"/bert/download-model/{training_id}"
|
| 228 |
|
| 229 |
return TrainingResponse(
|
| 230 |
message="Training started successfully",
|
|
@@ -233,7 +225,7 @@ async def start_training(
|
|
| 233 |
download_url=download_url
|
| 234 |
)
|
| 235 |
|
| 236 |
-
@app.post("/bert/validate")
|
| 237 |
async def validate_model(
|
| 238 |
file: UploadFile = File(...),
|
| 239 |
model_name: str = "BERT_model"
|
|
@@ -319,7 +311,7 @@ async def validate_model(
|
|
| 319 |
if os.path.exists(file_path):
|
| 320 |
os.remove(file_path)
|
| 321 |
|
| 322 |
-
@app.post("/bert/predict")
|
| 323 |
async def predict(
|
| 324 |
request: Optional[PredictionRequest] = None,
|
| 325 |
file: Optional[UploadFile] = File(None),
|
|
@@ -492,7 +484,7 @@ async def predict(
|
|
| 492 |
except Exception as e:
|
| 493 |
raise HTTPException(status_code=500, detail=str(e))
|
| 494 |
|
| 495 |
-
@app.get("/bert/download-model/{model_id}")
|
| 496 |
async def download_model(model_id: str):
|
| 497 |
"""Download a trained model"""
|
| 498 |
model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
|
|
|
|
| 180 |
async def root():
|
| 181 |
return {"message": "BERT Compliance Predictor API"}
|
| 182 |
|
| 183 |
+
@app.get("/v1/bert/health")
|
| 184 |
async def health_check():
|
| 185 |
return {"status": "healthy"}
|
| 186 |
|
| 187 |
+
@app.get("/v1/bert/training-status")
|
| 188 |
async def get_training_status():
|
| 189 |
return training_status
|
| 190 |
|
| 191 |
+
@app.post("/v1/bert/train", response_model=TrainingResponse)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
async def start_training(
|
| 193 |
config: TrainingConfig,
|
| 194 |
background_tasks: BackgroundTasks,
|
| 195 |
+
file: UploadFile = File(...)
|
| 196 |
):
|
| 197 |
if training_status["is_training"]:
|
| 198 |
raise HTTPException(status_code=400, detail="Training is already in progress")
|
| 199 |
|
| 200 |
+
if not file.filename.endswith('.csv'):
|
| 201 |
+
raise HTTPException(status_code=400, detail="Only CSV files are allowed")
|
| 202 |
+
|
| 203 |
+
file_path = UPLOAD_DIR / file.filename
|
| 204 |
+
with file_path.open("wb") as buffer:
|
| 205 |
+
shutil.copyfileobj(file.file, buffer)
|
| 206 |
|
| 207 |
training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 208 |
|
|
|
|
| 214 |
"status": "starting"
|
| 215 |
})
|
| 216 |
|
| 217 |
+
background_tasks.add_task(train_model_task, config, str(file_path), training_id)
|
| 218 |
|
| 219 |
+
download_url = f"/v1/bert/download-model/{training_id}"
|
| 220 |
|
| 221 |
return TrainingResponse(
|
| 222 |
message="Training started successfully",
|
|
|
|
| 225 |
download_url=download_url
|
| 226 |
)
|
| 227 |
|
| 228 |
+
@app.post("/v1/bert/validate")
|
| 229 |
async def validate_model(
|
| 230 |
file: UploadFile = File(...),
|
| 231 |
model_name: str = "BERT_model"
|
|
|
|
| 311 |
if os.path.exists(file_path):
|
| 312 |
os.remove(file_path)
|
| 313 |
|
| 314 |
+
@app.post("/v1/bert/predict")
|
| 315 |
async def predict(
|
| 316 |
request: Optional[PredictionRequest] = None,
|
| 317 |
file: Optional[UploadFile] = File(None),
|
|
|
|
| 484 |
except Exception as e:
|
| 485 |
raise HTTPException(status_code=500, detail=str(e))
|
| 486 |
|
| 487 |
+
@app.get("/v1/bert/download-model/{model_id}")
|
| 488 |
async def download_model(model_id: str):
|
| 489 |
"""Download a trained model"""
|
| 490 |
model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
|