#!/usr/bin/env python3
"""
EXAONE fine-tuning Space: FastAPI application.

Exposes endpoints to launch an AutoTrain LoRA fine-tuning run of
EXAONE 4.0 1.2B as a subprocess, poll its status, read/stream its output,
and stop it.
"""
import os
import json
import math
import re
import signal
import subprocess
import asyncio
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any, List, Optional
import logging

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="EXAONE Fine-tuning",
    description="EXAONE 4.0 1.2B 모델 파인튜닝 API",
    version="1.0.0",
)

# Number of training epochs. Shared by the status payload and the AutoTrain
# --epochs flag so the reported progress matches what is actually run.
TOTAL_EPOCHS = 3

# File the training subprocess output is mirrored to; read by /logs and
# /logs/stream.
LOG_FILE = Path("/app/training.log")

# Seconds to wait for the training process to exit after SIGTERM before
# escalating to SIGKILL.
STOP_TIMEOUT_SECONDS = 30

# Global state shared between the request handlers (event loop) and the
# training worker (threadpool).
training_status: Dict[str, Any] = {
    "is_running": False,
    "progress": 0,
    "current_epoch": 0,
    "total_epochs": TOTAL_EPOCHS,
    "loss": 0.0,
    "status": "idle",
}

# Handle to the running AutoTrain process so /stop_training can terminate it.
_training_process: Optional[subprocess.Popen] = None
# Set by /stop_training; tells the worker that the process exit was requested.
_stop_requested = threading.Event()

# A number such as 1, 0.5, -1.2e-3.
_NUMBER = r"(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
# Matches "loss 0.5", "loss: 0.5", "loss=0.5" and HF-style "'loss': 0.5", but
# not "train_loss"/"eval_loss" (\b needs a non-word character before "loss").
_LOSS_RE = re.compile(r"\bloss\b['\"]?\s*[:=]?\s*" + _NUMBER, re.IGNORECASE)
# Matches "Epoch 2", "epoch: 1.5" and HF-style "'epoch': 0.12"; the trailing
# \b keeps it from matching "--epochs 3".
_EPOCH_RE = re.compile(r"\bepoch\b['\"]?\s*[:=]?\s*" + _NUMBER, re.IGNORECASE)


class TrainingRequest(BaseModel):
    """Request body for /start_training."""

    # Hub repository the fine-tuned model is pushed to (--hub_model_id).
    model_name: str = "amis5895/exaone-1p2b-nutrition-kdri"
    # Training data location (--data_path).
    dataset_path: str = "/app/data"
    # NOTE(review): currently unused; hyperparameters are passed as CLI flags
    # in _build_command. Kept for API compatibility.
    config_path: str = "/app/autotrain_ultra_low_final.yaml"


@app.get("/")
async def root():
    """Root endpoint: basic service information."""
    return {
        "message": "EXAONE Fine-tuning API",
        "status": "running",
        "version": "1.0.0",
    }


@app.post("/start_training")
async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Start a training run in the background.

    Raises:
        HTTPException: 400 if a run is already in progress, including one that
            is still shutting down after /stop_training.
    """
    if training_status["is_running"]:
        raise HTTPException(status_code=400, detail="Training is already running")

    _stop_requested.clear()
    training_status.update({
        "is_running": True,
        "progress": 0,
        "current_epoch": 0,
        "loss": 0.0,  # don't report the previous run's loss
        "status": "starting",
    })

    # run_training is a sync function, so Starlette runs it in its threadpool.
    background_tasks.add_task(run_training, request)

    return {
        "message": "Training started",
        "status": "starting",
        "model_name": request.model_name,
    }


def _build_command(request: TrainingRequest) -> List[str]:
    """Build the `autotrain llm` command line for *request*."""
    return [
        "autotrain", "llm", "--train",
        "--project_name", "exaone-finetuning",
        "--model", "LGAI-EXAONE/EXAONE-4.0-1.2B",
        "--data_path", request.dataset_path,
        "--text_column", "text",
        "--use_peft",
        "--quantization", "int4",
        "--lora_r", "16",
        "--lora_alpha", "32",
        "--lora_dropout", "0.05",
        "--target_modules", "all-linear",
        "--epochs", str(TOTAL_EPOCHS),
        "--batch_size", "4",
        "--gradient_accumulation", "4",
        "--learning_rate", "2e-4",
        "--warmup_ratio", "0.03",
        "--mixed_precision", "fp16",
        "--push_to_hub",
        "--hub_model_id", request.model_name,
        "--username", "amis5895",
    ]


def _update_progress(line: str) -> None:
    """Update epoch, progress and loss in training_status from one output line."""
    epoch_match = _EPOCH_RE.search(line)
    total = training_status["total_epochs"]
    if epoch_match and total:
        epoch = float(epoch_match.group(1))
        # HF-style logs report fractional epochs (0.5 = halfway through the
        # first epoch). Clamp so stray lines can't push progress past 100 %.
        training_status["current_epoch"] = min(total, max(0, math.ceil(epoch)))
        training_status["progress"] = min(100.0, max(0.0, epoch / total * 100))

    loss_match = _LOSS_RE.search(line)
    if loss_match:
        training_status["loss"] = float(loss_match.group(1))


def _open_log_file():
    """Open (truncate) LOG_FILE for writing; return None if that is impossible."""
    try:
        LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
        return open(LOG_FILE, "w", encoding="utf-8")
    except OSError:
        logger.warning("Could not open %s; output will only go to the logger", LOG_FILE)
        return None


def _terminate_process_tree(process: subprocess.Popen, force: bool = False) -> None:
    """Signal the training process and its children (its own process group).

    AutoTrain may launch child worker processes. Signalling only the direct
    child would leave them running and holding the stdout pipe open.
    """
    if hasattr(os, "killpg"):
        sig = signal.SIGKILL if force else signal.SIGTERM
        try:
            os.killpg(process.pid, sig)
        except ProcessLookupError:
            pass  # already exited
        return
    # Non-POSIX fallback: signal just the direct child.
    if force:
        process.kill()
    else:
        process.terminate()


def run_training(request: TrainingRequest) -> None:
    """Run AutoTrain to completion and record the outcome in training_status.

    This is deliberately a plain (sync) function. Starlette runs sync
    background tasks in a threadpool, whereas an ``async def`` task runs on the
    event loop, where the blocking pipe reads below would freeze every other
    endpoint for the whole run.

    This worker owns the final ``is_running=False`` transition, so a new run
    can only start once the previous process has really exited.
    """
    global _training_process
    log_fh = None
    final_state: Dict[str, Any] = {"status": "error"}
    try:
        if _stop_requested.is_set():
            # /stop_training arrived before the process was spawned.
            final_state = {"status": "stopped"}
            return

        logger.info("Starting training process...")
        training_status["status"] = "running"
        log_fh = _open_log_file()

        process = subprocess.Popen(
            _build_command(request),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            # Own process group so the whole tree can be signalled on stop.
            start_new_session=True,
        )
        _training_process = process
        if _stop_requested.is_set():
            # A stop request raced with Popen; honour it now.
            _terminate_process_tree(process)

        # Monitor training progress from the combined stdout/stderr stream.
        with process.stdout:
            for line in process.stdout:
                logger.info("%s", line.rstrip())
                if log_fh is not None:
                    log_fh.write(line)
                    log_fh.flush()  # make lines visible to /logs/stream promptly
                _update_progress(line)

        returncode = process.wait()
        if _stop_requested.is_set():
            final_state = {"status": "stopped"}
            logger.info("Training stopped by user.")
        elif returncode == 0:
            final_state = {"progress": 100, "status": "completed"}
            logger.info("Training completed successfully!")
        else:
            final_state = {"status": "failed"}
            logger.error("Training failed! (exit code %d)", returncode)

    except Exception as e:
        logger.exception("Training error: %s", e)
        final_state = {"status": "error"}
    finally:
        _training_process = None
        if log_fh is not None:
            log_fh.close()
        training_status.update({"is_running": False, **final_state})


@app.get("/status")
async def get_status():
    """Return a snapshot of the current training status."""
    return dict(training_status)


@app.get("/logs")
async def get_logs():
    """Return the full content of the training log, if any."""
    try:
        logs = LOG_FILE.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {"logs": "No logs available"}
    return {"logs": logs}


@app.get("/logs/stream")
async def stream_logs():
    """Stream the training log as server-sent events.

    Sends every line already in the log. While a run is active, it keeps
    following the file for new lines, and it ends once training is no longer
    running.
    """
    def generate_logs():
        try:
            log_fh = open(LOG_FILE, "r", encoding="utf-8")
        except FileNotFoundError:
            yield "data: No logs available\n\n"
            return
        with log_fh:
            while True:
                line = log_fh.readline()
                if line:
                    # An SSE frame is "data: <text>\n\n"; strip the line's own
                    # newline so it doesn't add an extra blank line.
                    text = line.rstrip("\r\n")
                    yield f"data: {text}\n\n"
                elif training_status["is_running"]:
                    # Sync generator: Starlette iterates it in a worker thread,
                    # so sleeping here does not block the event loop.
                    time.sleep(0.5)
                else:
                    break

    return StreamingResponse(generate_logs(), media_type="text/event-stream")


@app.post("/stop_training")
async def stop_training():
    """Stop the running training process.

    Sends SIGTERM to the training process group and waits up to
    STOP_TIMEOUT_SECONDS without blocking the event loop, then sends SIGKILL.
    The worker in run_training records the final "stopped" status.

    Raises:
        HTTPException: 400 if no training is running.
    """
    if not training_status["is_running"]:
        raise HTTPException(status_code=400, detail="No training is running")

    _stop_requested.set()
    training_status["status"] = "stopping"

    process = _training_process
    if process is not None and process.poll() is None:
        _terminate_process_tree(process)
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(None, process.wait, STOP_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            logger.warning("Training process ignored SIGTERM; killing it")
            _terminate_process_tree(process, force=True)
            await loop.run_in_executor(None, process.wait)

    return {"message": "Training stopped"}


@app.get("/health")
async def health_check():
    """Health check with the current UTC time."""
    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)