File size: 4,418 Bytes
f73ae00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import os
import subprocess
import sys
import time
import urllib.request
from collections import deque

# ==========================================================
# Terminal Agent (Gemini API) for BS-RoKAN 監視
# VRAM消費: 0GB / CPU負荷: 極小
# ==========================================================

# APIキーをファイルから読み込む
KEY_FILE = "APIKey From Google AI Studio.txt"
if os.path.exists(KEY_FILE):
    with open(KEY_FILE, "r") as f:
        API_KEY = f.read().strip()
else:
    API_KEY = os.environ.get("GEMINI_API_KEY", "")

MODEL_NAME = "gemini-3.1-flash-lite"

def analyze_logs_with_llm(log_buffer):
    if not API_KEY:
        print("[Agent] API_KEYがないため判定をスキップ(OK)")
        return "OK"
        
    system_instruction = "あなたは音声分離モデルBS-RoKANの学習監視エージェントです。以下の学習ログを見て、学習が順調か評価してください。"
    prompt = f"{system_instruction} 出力は OK, LOWER_LR, RESTART のいずれか1語のみにしてください。 \n\nログ:\n" + "\n".join(log_buffer)
    
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent?key={API_KEY}"
    
    # Gemini API (REST) format
    payload = {
        "contents": [{
            "parts": [{"text": prompt}]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "maxOutputTokens": 10
        }
    }
    
    try:
        req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"})
        with urllib.request.urlopen(req, timeout=15) as r:
            response = json.loads(r.read())
            # Extract text from Gemini response structure
            decision = response["candidates"][0]["content"]["parts"][0]["text"].strip().upper()
            
            if "LOWER_LR" in decision: return "LOWER_LR"
            if "RESTART" in decision: return "RESTART"
            return "OK"
    except Exception as e:
        print(f"[Agent] Gemini APIエラー: {e}")
        return "OK"

def main():
    print(f"[*] Gemini Terminal Agent 起動成功 (Model: {MODEL_NAME})")
    print(f"[*] 学習プロセスを起動中...")
    
    # RX 9070 XT想定: WSL2上でバッチサイズ2で開始
    cmd = ["python", "-u", "train_rokan.py", "--batch_size", "2"] 
    
    while True:
        print(f"\n[Agent] 訓練開始: {' '.join(cmd)}")
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
        log_buffer = []
        
        try:
            for line in process.stdout:
                line = line.strip()
                if not line: continue
                print(line)
                
                if "Loss" in line or "Saved:" in line:
                    log_buffer.append(line)
                    
                # セーブ(Epoch終了)ごとにGeminiで診断を行う
                if "Saved:" in line and len(log_buffer) > 5:
                    decision = analyze_logs_with_llm(log_buffer[-30:])
                    if decision == "LOWER_LR":
                        print(f"[Agent] Geminiの判定: {decision} (学習率を下げて再開します)")
                        process.terminate()
                        if "--gate_lr" not in cmd: 
                            cmd.extend(["--gate_lr", "5e-4"]) # 1e-3 -> 5e-4
                        break 
                    elif decision == "RESTART":
                        print(f"[Agent] Geminiの判定: {decision} (異常検知につき再起動します)")
                        process.terminate()
                        time.sleep(5)
                        break 
                    else:
                        print(f"[Agent] Geminiの判定: {decision} (順調です)")
                        log_buffer = [] # バッファをクリア
                        
        except KeyboardInterrupt:
            print("\n[Agent] ユーザーによる中断。プロセスを終了します。")
            process.terminate()
            sys.exit(0)
            
        process.wait()
        if process.returncode != 0 and process.returncode is not None:
            print(f"[Agent] 訓練プロセスが終了しました (Code: {process.returncode})。10秒後に再起動を試みます。")
            time.sleep(10)

# Start the supervisor only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()