Spaces:
Running
Running
Fix: add health server on port 7860 to prevent timeout
Browse files — train_on_hf.py (+41 −1)
train_on_hf.py
CHANGED
|
@@ -351,10 +351,40 @@ def merge_and_push(hf_token: str):
|
|
| 351 |
print("Pushed to https://huggingface.co/Rayugacodes/kernelx-strategist")
|
| 352 |
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
def main():
|
|
|
|
|
|
|
|
|
|
| 355 |
parser = argparse.ArgumentParser(description="KernelX GPU Training on HF")
|
| 356 |
parser.add_argument("--hf-token", required=True, help="HuggingFace token")
|
| 357 |
-
parser.add_argument("--world-model-samples", type=int, default=
|
| 358 |
parser.add_argument("--strategist-samples", type=int, default=10000)
|
| 359 |
parser.add_argument("--skip-world-model", action="store_true")
|
| 360 |
parser.add_argument("--skip-strategist", action="store_true")
|
|
@@ -362,22 +392,32 @@ def main():
|
|
| 362 |
args = parser.parse_args()
|
| 363 |
|
| 364 |
# Setup
|
|
|
|
| 365 |
data_dir = setup(args.hf_token)
|
| 366 |
|
| 367 |
# Train
|
| 368 |
if not args.skip_world_model:
|
|
|
|
| 369 |
train_world_model(data_dir, max_samples=args.world_model_samples)
|
| 370 |
|
| 371 |
if not args.skip_strategist:
|
|
|
|
| 372 |
train_strategist(data_dir, max_samples=args.strategist_samples)
|
| 373 |
|
| 374 |
if not args.skip_merge:
|
|
|
|
| 375 |
merge_and_push(args.hf_token)
|
| 376 |
|
|
|
|
| 377 |
print("\n=== All done! ===")
|
| 378 |
print("Model: https://huggingface.co/Rayugacodes/kernelx-strategist")
|
| 379 |
print("Next: convert to GGUF for sub-50ms CPU inference")
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
if __name__ == "__main__":
|
| 383 |
main()
|
|
|
|
| 351 |
print("Pushed to https://huggingface.co/Rayugacodes/kernelx-strategist")
|
| 352 |
|
| 353 |
|
def start_health_server(port: int = 7860):
    """Start a tiny background HTTP status server so the Space stays alive.

    HF Spaces expects the container to serve on port 7860 and will kill it
    on timeout otherwise; serving a minimal status page prevents that
    during long training runs.

    Args:
        port: TCP port to bind. Defaults to 7860, the port HF Spaces
            probes; pass 0 to bind an ephemeral port (useful in tests).

    Returns:
        A mutable status dict ({"stage": ...}). Callers update
        status["stage"] and the served page reflects it on the next
        refresh.
    """
    import threading
    from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

    status = {"stage": "starting"}

    class Handler(BaseHTTPRequestHandler):
        def _send_headers(self):
            self.send_response(200)
            self.send_header("Content-Type", "text/html")
            self.end_headers()

        def do_HEAD(self):
            # Health probes may use HEAD; without this the default
            # handler would answer 501 and the probe could fail.
            self._send_headers()

        def do_GET(self):
            self._send_headers()
            self.wfile.write(
                (
                    "<html><body><h1>KernelX Training</h1>"
                    f"<p>Stage: <b>{status['stage']}</b></p>"
                    "<p>Refresh to check progress.</p></body></html>"
                ).encode()
            )

        def log_message(self, format, *args):
            pass  # suppress per-request access-log noise

    # ThreadingHTTPServer so one slow/stuck request can't block later probes.
    server = ThreadingHTTPServer(("0.0.0.0", port), Handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    print(f"Health server running on port {port}")
    return status
def main():
    """CLI entry point: start health server, set up data, train, merge/push.

    Runs forever after training completes so the HF Space stays up and the
    status page remains browsable.
    """
    # Start the health server FIRST so HF Spaces doesn't kill the
    # container before training even begins.
    status = start_health_server()

    parser = argparse.ArgumentParser(description="KernelX GPU Training on HF")
    parser.add_argument("--hf-token", required=True, help="HuggingFace token")
    parser.add_argument("--world-model-samples", type=int, default=10000)
    parser.add_argument("--strategist-samples", type=int, default=10000)
    parser.add_argument("--skip-world-model", action="store_true")
    parser.add_argument("--skip-strategist", action="store_true")
    # Required by the skip-merge guard below; without it parse_args()
    # yields no `skip_merge` attribute and main() crashes.
    parser.add_argument("--skip-merge", action="store_true")
    args = parser.parse_args()

    # Setup
    status["stage"] = "downloading data"
    data_dir = setup(args.hf_token)

    # Train
    if not args.skip_world_model:
        status["stage"] = "training world model"
        train_world_model(data_dir, max_samples=args.world_model_samples)

    if not args.skip_strategist:
        status["stage"] = "training strategist"
        train_strategist(data_dir, max_samples=args.strategist_samples)

    if not args.skip_merge:
        status["stage"] = "merging and pushing to HF"
        merge_and_push(args.hf_token)

    status["stage"] = "DONE"
    print("\n=== All done! ===")
    print("Model: https://huggingface.co/Rayugacodes/kernelx-strategist")
    print("Next: convert to GGUF for sub-50ms CPU inference")

    # Keep the process alive so the Space (and its status page) stays up.
    import time

    while True:
        time.sleep(60)


if __name__ == "__main__":
    main()