# DGX_AI / app.py
# vasiuuu's picture
# Add /smoke endpoint for quick pipeline testing
# 4e35ea8
from fastapi import FastAPI, BackgroundTasks
import subprocess
import threading
import os
from fastapi.responses import PlainTextResponse
# FastAPI application instance; the route handlers in this module register
# themselves on it through the @app.get / @app.post decorators.
app = FastAPI()
# Module-wide record of the current (or most recent) training run.
#   "status": "idle" until a run starts; run_training() then sets it to
#             "running" and finally to a "finished with code N" string.
#   "log":    text accumulated from the trainer's output, served by /logs.
training_status = dict(status="idle", log="")
def run_training(steps: int = 300):
    """Run the trainer as a subprocess and mirror its output into training_status.

    Blocks until the trainer exits, so it is meant to be executed on a
    worker thread. While it runs, ``training_status["status"]`` is
    ``"running"`` and ``training_status["log"]`` grows line by line with the
    trainer's combined stdout/stderr.

    Args:
        steps: Number of training steps, forwarded as ``--steps`` to
            ``python -m trainer.train``.

    On completion the status becomes ``"finished with code <returncode>"``.
    If the trainer process cannot be launched at all, the status becomes
    ``"failed to launch"`` and the error is appended to the log, so the
    state never stays stuck on ``"running"``.
    """
    training_status["status"] = "running"
    training_status["log"] = f"Started training for {steps} steps...\n"

    command = ["python", "-m", "trainer.train", "--steps", str(steps)]
    try:
        # The context manager closes the stdout pipe and waits for the child
        # on exit, so the pipe is not leaked even if reading fails midway.
        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr into the same stream
            text=True,
            errors="replace",  # undecodable bytes must not abort log capture
            bufsize=1,  # line-buffered so /logs sees output as it is produced
        ) as process:
            for line in process.stdout:
                training_status["log"] += line
    except OSError as exc:
        # Launch failed (e.g. interpreter not found). Record it instead of
        # leaving the status on "running", which would block all later runs.
        training_status["status"] = "failed to launch"
        training_status["log"] += f"\nFailed to launch trainer: {exc}\n"
        return

    returncode = process.returncode
    training_status["status"] = f"finished with code {returncode}"
    training_status["log"] += f"\nTraining finished with exit code: {returncode}\n"
@app.get("/")
def read_root():
    """Return a service banner, the current training status, and an endpoint index."""
    # Human-readable description of each route exposed by this service.
    endpoint_help = {
        "/start": "Start full training (300 steps)",
        "/smoke": "Run smoke test (5 steps)",
        "/logs": "View live training logs",
    }
    current_status = training_status["status"]
    return {
        "message": "CodeForge GRPO Training Node Active",
        "status": current_status,
        "endpoints": endpoint_help,
    }
# Serializes the "is a run active?" check with claiming the running state, so
# concurrent requests cannot both decide that no run is active.
_start_lock = threading.Lock()


def _launch_training(steps: int) -> bool:
    """Start run_training(steps) on a new thread unless a run is already active.

    Returns True when a new run was started, False when one was already
    running.
    """
    with _start_lock:
        previous_status = training_status["status"]
        if previous_status == "running":
            return False
        # Claim the slot before the worker thread exists; run_training only
        # sets "running" once it is scheduled, which would leave a window in
        # which a second request could start a duplicate run.
        training_status["status"] = "running"
    try:
        threading.Thread(target=run_training, args=(steps,)).start()
    except RuntimeError:
        # The thread could not be started: release the claim so that later
        # requests are not refused forever.
        training_status["status"] = previous_status
        raise
    return True


@app.post("/start")
def start_training(background_tasks: BackgroundTasks):
    """Start the full 300-step training run in the background.

    Args:
        background_tasks: Injected by FastAPI; currently unused because the
            run is executed on a dedicated thread instead.

    Returns:
        A dict with a single "message" key saying whether the run was
        started or one was already in progress.
    """
    if not _launch_training(300):
        return {"message": "Training is already running!"}
    return {"message": "Full training started! Go to /logs to monitor."}


@app.post("/smoke")
def start_smoke_test(background_tasks: BackgroundTasks):
    """Start a 5-step smoke-test run in the background.

    Args:
        background_tasks: Injected by FastAPI; currently unused because the
            run is executed on a dedicated thread instead.

    Returns:
        A dict with a single "message" key saying whether the smoke test was
        started or a run was already in progress.
    """
    if not _launch_training(5):
        return {"message": "Training is already running!"}
    return {"message": "Smoke test (5 steps) started! Go to /logs to monitor."}
@app.get("/logs", response_class=PlainTextResponse)
def get_logs():
    """Return the training output captured so far as a plain-text response."""
    captured_output = training_status["log"]
    return captured_output