Update app.py
Browse files
app.py
CHANGED
|
@@ -184,11 +184,11 @@ async def root():
|
|
| 184 |
async def health_check():
|
| 185 |
return {"status": "healthy"}
|
| 186 |
|
| 187 |
-
@app.get("/
|
| 188 |
async def get_training_status():
|
| 189 |
return training_status
|
| 190 |
|
| 191 |
-
@app.post("/
|
| 192 |
async def upload_file(file: UploadFile = File(...)):
|
| 193 |
"""Upload a CSV file for training or validation"""
|
| 194 |
if not file.filename.endswith('.csv'):
|
|
@@ -200,7 +200,7 @@ async def upload_file(file: UploadFile = File(...)):
|
|
| 200 |
|
| 201 |
return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
|
| 202 |
|
| 203 |
-
@app.post("/
|
| 204 |
async def start_training(
|
| 205 |
config: TrainingConfig,
|
| 206 |
background_tasks: BackgroundTasks,
|
|
@@ -224,7 +224,7 @@ async def start_training(
|
|
| 224 |
|
| 225 |
background_tasks.add_task(train_model_task, config, file_path, training_id)
|
| 226 |
|
| 227 |
-
download_url = f"/
|
| 228 |
|
| 229 |
return TrainingResponse(
|
| 230 |
message="Training started successfully",
|
|
@@ -233,7 +233,7 @@ async def start_training(
|
|
| 233 |
download_url=download_url
|
| 234 |
)
|
| 235 |
|
| 236 |
-
@app.post("/
|
| 237 |
async def validate_model(
|
| 238 |
file: UploadFile = File(...),
|
| 239 |
model_name: str = "BERT_model"
|
|
@@ -319,7 +319,7 @@ async def validate_model(
|
|
| 319 |
if os.path.exists(file_path):
|
| 320 |
os.remove(file_path)
|
| 321 |
|
| 322 |
-
@app.post("/
|
| 323 |
async def predict(
|
| 324 |
request: Optional[PredictionRequest] = None,
|
| 325 |
file: Optional[UploadFile] = File(None),
|
|
@@ -510,80 +510,51 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
|
|
| 510 |
data_df_original, label_encoders = load_and_preprocess_data(file_path)
|
| 511 |
save_label_encoders(label_encoders)
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
test_size=config.test_size,
|
| 516 |
-
random_state=config.random_state,
|
| 517 |
-
stratify=data_df_original[LABEL_COLUMNS[0]]
|
| 518 |
-
)
|
| 519 |
-
|
| 520 |
-
train_texts = train_df[TEXT_COLUMN]
|
| 521 |
-
val_texts = val_df[TEXT_COLUMN]
|
| 522 |
-
train_labels_array = train_df[LABEL_COLUMNS].values
|
| 523 |
-
val_labels_array = val_df[LABEL_COLUMNS].values
|
| 524 |
|
| 525 |
-
|
| 526 |
-
val_metadata_df = val_df[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in val_df.columns for col in METADATA_COLUMNS) else None
|
| 527 |
|
| 528 |
num_labels_list = get_num_labels(label_encoders)
|
| 529 |
tokenizer = get_tokenizer(config.model_name)
|
| 530 |
|
| 531 |
-
if
|
| 532 |
-
metadata_dim =
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
tokenizer,
|
| 538 |
-
config.max_length
|
| 539 |
-
)
|
| 540 |
-
val_dataset = ComplianceDatasetWithMetadata(
|
| 541 |
-
val_texts.tolist(),
|
| 542 |
-
val_metadata_df.values,
|
| 543 |
-
val_labels_array,
|
| 544 |
tokenizer,
|
| 545 |
config.max_length
|
| 546 |
)
|
| 547 |
model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
|
| 548 |
else:
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
tokenizer,
|
| 553 |
-
config.max_length
|
| 554 |
-
)
|
| 555 |
-
val_dataset = ComplianceDataset(
|
| 556 |
-
val_texts.tolist(),
|
| 557 |
-
val_labels_array,
|
| 558 |
tokenizer,
|
| 559 |
config.max_length
|
| 560 |
)
|
| 561 |
model = BertMultiOutputModel(num_labels_list).to(DEVICE)
|
| 562 |
|
| 563 |
-
train_loader = DataLoader(
|
| 564 |
-
val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
|
| 565 |
|
| 566 |
criterions = initialize_criterions(num_labels_list)
|
| 567 |
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
| 568 |
|
| 569 |
-
best_val_loss = float('inf')
|
| 570 |
for epoch in range(config.num_epochs):
|
| 571 |
training_status["current_epoch"] = epoch + 1
|
| 572 |
|
| 573 |
train_loss = train_model(model, train_loader, criterions, optimizer)
|
| 574 |
-
val_metrics, _, _ = evaluate_model(model, val_loader)
|
| 575 |
-
|
| 576 |
training_status["current_loss"] = train_loss
|
| 577 |
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
save_model(model, training_id)
|
| 581 |
|
| 582 |
training_status.update({
|
| 583 |
"is_training": False,
|
| 584 |
"end_time": datetime.now().isoformat(),
|
| 585 |
-
"status": "completed"
|
| 586 |
-
"metrics": summarize_metrics(val_metrics).to_dict()
|
| 587 |
})
|
| 588 |
|
| 589 |
except Exception as e:
|
|
|
|
| 184 |
async def health_check():
|
| 185 |
return {"status": "healthy"}
|
| 186 |
|
| 187 |
+
@app.get("/training-status")
|
| 188 |
async def get_training_status():
|
| 189 |
return training_status
|
| 190 |
|
| 191 |
+
@app.post("/upload")
|
| 192 |
async def upload_file(file: UploadFile = File(...)):
|
| 193 |
"""Upload a CSV file for training or validation"""
|
| 194 |
if not file.filename.endswith('.csv'):
|
|
|
|
| 200 |
|
| 201 |
return {"message": f"File {file.filename} uploaded successfully", "file_path": str(file_path)}
|
| 202 |
|
| 203 |
+
@app.post("/bert/train", response_model=TrainingResponse)
|
| 204 |
async def start_training(
|
| 205 |
config: TrainingConfig,
|
| 206 |
background_tasks: BackgroundTasks,
|
|
|
|
| 224 |
|
| 225 |
background_tasks.add_task(train_model_task, config, file_path, training_id)
|
| 226 |
|
| 227 |
+
download_url = f"/bert/download-model/{training_id}"
|
| 228 |
|
| 229 |
return TrainingResponse(
|
| 230 |
message="Training started successfully",
|
|
|
|
| 233 |
download_url=download_url
|
| 234 |
)
|
| 235 |
|
| 236 |
+
@app.post("/bert/validate")
|
| 237 |
async def validate_model(
|
| 238 |
file: UploadFile = File(...),
|
| 239 |
model_name: str = "BERT_model"
|
|
|
|
| 319 |
if os.path.exists(file_path):
|
| 320 |
os.remove(file_path)
|
| 321 |
|
| 322 |
+
@app.post("/bert/predict")
|
| 323 |
async def predict(
|
| 324 |
request: Optional[PredictionRequest] = None,
|
| 325 |
file: Optional[UploadFile] = File(None),
|
|
|
|
| 510 |
data_df_original, label_encoders = load_and_preprocess_data(file_path)
|
| 511 |
save_label_encoders(label_encoders)
|
| 512 |
|
| 513 |
+
texts = data_df_original[TEXT_COLUMN]
|
| 514 |
+
labels_array = data_df_original[LABEL_COLUMNS].values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
+
metadata_df = data_df_original[METADATA_COLUMNS] if METADATA_COLUMNS and all(col in data_df_original.columns for col in METADATA_COLUMNS) else None
|
|
|
|
| 517 |
|
| 518 |
num_labels_list = get_num_labels(label_encoders)
|
| 519 |
tokenizer = get_tokenizer(config.model_name)
|
| 520 |
|
| 521 |
+
if metadata_df is not None:
|
| 522 |
+
metadata_dim = metadata_df.shape[1]
|
| 523 |
+
dataset = ComplianceDatasetWithMetadata(
|
| 524 |
+
texts.tolist(),
|
| 525 |
+
metadata_df.values,
|
| 526 |
+
labels_array,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
tokenizer,
|
| 528 |
config.max_length
|
| 529 |
)
|
| 530 |
model = BertMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
|
| 531 |
else:
|
| 532 |
+
dataset = ComplianceDataset(
|
| 533 |
+
texts.tolist(),
|
| 534 |
+
labels_array,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
tokenizer,
|
| 536 |
config.max_length
|
| 537 |
)
|
| 538 |
model = BertMultiOutputModel(num_labels_list).to(DEVICE)
|
| 539 |
|
| 540 |
+
train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
|
|
|
|
| 541 |
|
| 542 |
criterions = initialize_criterions(num_labels_list)
|
| 543 |
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
| 544 |
|
|
|
|
| 545 |
for epoch in range(config.num_epochs):
|
| 546 |
training_status["current_epoch"] = epoch + 1
|
| 547 |
|
| 548 |
train_loss = train_model(model, train_loader, criterions, optimizer)
|
|
|
|
|
|
|
| 549 |
training_status["current_loss"] = train_loss
|
| 550 |
|
| 551 |
+
# Save model after each epoch
|
| 552 |
+
save_model(model, training_id)
|
|
|
|
| 553 |
|
| 554 |
training_status.update({
|
| 555 |
"is_training": False,
|
| 556 |
"end_time": datetime.now().isoformat(),
|
| 557 |
+
"status": "completed"
|
|
|
|
| 558 |
})
|
| 559 |
|
| 560 |
except Exception as e:
|