| import os |
| import sys |
| import time |
| import subprocess |
| import json |
| import urllib.request |
|
|
| |
| |
| |
| |
|
|
| |
| KEY_FILE = "APIKey From Google AI Studio.txt" |
| if os.path.exists(KEY_FILE): |
| with open(KEY_FILE, "r") as f: |
| API_KEY = f.read().strip() |
| else: |
| API_KEY = os.environ.get("GEMINI_API_KEY", "") |
|
|
| MODEL_NAME = "gemini-3.1-flash-lite" |
|
|
| def analyze_logs_with_llm(log_buffer): |
| if not API_KEY: |
| print("[Agent] API_KEYがないため判定をスキップ(OK)") |
| return "OK" |
| |
| system_instruction = "あなたは音声分離モデルBS-RoKANの学習監視エージェントです。以下の学習ログを見て、学習が順調か評価してください。" |
| prompt = f"{system_instruction} 出力は OK, LOWER_LR, RESTART のいずれか1語のみにしてください。 \n\nログ:\n" + "\n".join(log_buffer) |
| |
| url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent?key={API_KEY}" |
| |
| |
| payload = { |
| "contents": [{ |
| "parts": [{"text": prompt}] |
| }], |
| "generationConfig": { |
| "temperature": 0.1, |
| "maxOutputTokens": 10 |
| } |
| } |
| |
| try: |
| req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}) |
| with urllib.request.urlopen(req, timeout=15) as r: |
| response = json.loads(r.read()) |
| |
| decision = response["candidates"][0]["content"]["parts"][0]["text"].strip().upper() |
| |
| if "LOWER_LR" in decision: return "LOWER_LR" |
| if "RESTART" in decision: return "RESTART" |
| return "OK" |
| except Exception as e: |
| print(f"[Agent] Gemini APIエラー: {e}") |
| return "OK" |
|
|
| def main(): |
| print(f"[*] Gemini Terminal Agent 起動成功 (Model: {MODEL_NAME})") |
| print(f"[*] 学習プロセスを起動中...") |
| |
| |
| cmd = ["python", "-u", "train_rokan.py", "--batch_size", "2"] |
| |
| while True: |
| print(f"\n[Agent] 訓練開始: {' '.join(cmd)}") |
| process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) |
| log_buffer = [] |
| |
| try: |
| for line in process.stdout: |
| line = line.strip() |
| if not line: continue |
| print(line) |
| |
| if "Loss" in line or "Saved:" in line: |
| log_buffer.append(line) |
| |
| |
| if "Saved:" in line and len(log_buffer) > 5: |
| decision = analyze_logs_with_llm(log_buffer[-30:]) |
| if decision == "LOWER_LR": |
| print(f"[Agent] Geminiの判定: {decision} (学習率を下げて再開します)") |
| process.terminate() |
| if "--gate_lr" not in cmd: |
| cmd.extend(["--gate_lr", "5e-4"]) |
| break |
| elif decision == "RESTART": |
| print(f"[Agent] Geminiの判定: {decision} (異常検知につき再起動します)") |
| process.terminate() |
| time.sleep(5) |
| break |
| else: |
| print(f"[Agent] Geminiの判定: {decision} (順調です)") |
| log_buffer = [] |
| |
| except KeyboardInterrupt: |
| print("\n[Agent] ユーザーによる中断。プロセスを終了します。") |
| process.terminate() |
| sys.exit(0) |
| |
| process.wait() |
| if process.returncode != 0 and process.returncode is not None: |
| print(f"[Agent] 訓練プロセスが終了しました (Code: {process.returncode})。10秒後に再起動を試みます。") |
| time.sleep(10) |
|
|
| if __name__ == "__main__": |
| main() |
|
|