Upload layer5_state_management.py with huggingface_hub
Browse files- layer5_state_management.py +131 -0
layer5_state_management.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import re
from collections import Counter
from datetime import datetime
|
class LayerResetManager:
    """Manage resets of Layer 24 (the LLM's final layer) — conceptual implementation."""

    def __init__(self, llm_model_mock):
        self.llm = llm_model_mock
        # Audit log of every reset performed, newest last.
        self.reset_history = []

    def reset_layer24_for_new_turn(self):
        """Simulate clearing Layer 24's KV cache ahead of a new conversation turn."""
        # A real LLM model object is expected to expose `clear_kv_cache`.
        clear = getattr(self.llm, 'clear_kv_cache', None)
        if callable(clear):
            clear(layer=24)

        print("Simulating: KV cache for Layer 24 has been reset.")
        self.reset_history.append({
            "timestamp": datetime.now().isoformat(),
            "action": "layer24_reset",
        })
class ExternalState:
    """
    Manage a session's external state, capping its memory footprint
    (~10 KB by default) by evicting the oldest turn summaries.
    """

    def __init__(self, max_size_bytes=10240):
        # Oldest-first list of {"turn", "summary", "size_bytes"} records.
        self.conversation_summary = []
        # Flat history of every DB coordinate used across all turns.
        self.coordinate_trail = []
        self.max_size_bytes = max_size_bytes
        # Running total (bytes) of all stored summaries.
        self.current_size = 0

    def _extract_keywords(self, text: str, max_words=3) -> list:
        """Naively extract up to *max_words* most frequent keywords from *text*."""
        words = re.findall(r'\b\w+\b', text.lower())
        # Simple stopword removal.
        # NOTE(review): \b\w+\b does not segment unspaced Japanese text, so
        # these single-character particles rarely match as standalone tokens —
        # confirm whether a proper tokenizer is needed here.
        stopwords = {"です", "ます", "が", "は", "を", "に", "と", "の"}
        words = [word for word in words if word not in stopwords]
        return [word for word, freq in Counter(words).most_common(max_words)]

    def _compress_turn(self, user_input: str, llm_response: str, db_coords: list) -> dict:
        """Summarize/compress one turn into a small dict."""
        return {
            "user_keywords": self._extract_keywords(user_input),
            # Keep only the first 100 characters of long responses.
            "response_summary": llm_response[:100] + "..." if len(llm_response) > 100 else llm_response,
            "db_coords_used": db_coords[:3],
            "turn_length": len(llm_response),
        }

    def _compress_old_turns(self):
        """Compress old turns (this implementation simply drops them) to stay within the size budget."""
        # BUGFIX: eviction was gated on turn COUNT (> 20) while the message and
        # the inner loop are about SIZE — so a short conversation could exceed
        # max_size_bytes indefinitely and the demo's "trigger compression" path
        # never ran. Gate on the size budget instead.
        if self.current_size > self.max_size_bytes:
            print("State size limit exceeded, compressing old turns...")
            # Evict from the oldest turn first until back under budget.
            while self.current_size > self.max_size_bytes and self.conversation_summary:
                removed_turn = self.conversation_summary.pop(0)
                self.current_size -= removed_turn.get("size_bytes", 0)

    def add_turn_summary(self, turn_num: int, user_input: str, llm_response: str, db_coords: list):
        """Append a compressed summary of one turn and enforce the size budget."""
        summary = self._compress_turn(user_input, llm_response, db_coords)
        # Size is measured on the UTF-8 JSON encoding of the summary.
        summary_size = len(json.dumps(summary, ensure_ascii=False).encode('utf-8'))

        turn_data = {
            "turn": turn_num,
            "summary": summary,
            "size_bytes": summary_size,
        }

        self.conversation_summary.append(turn_data)
        self.current_size += summary_size
        if db_coords:
            self.coordinate_trail.extend(db_coords)

        self._compress_old_turns()

    def _extract_key_coordinates(self) -> list:
        """Return the up-to-5 most frequently used coordinates from the trail."""
        if not self.coordinate_trail:
            return []
        # Coordinates are lists (unhashable); count them as tuples.
        coord_freq = Counter(tuple(c) for c in self.coordinate_trail)
        return [list(coord) for coord, count in coord_freq.most_common(5)]

    def get_context_for_next_turn(self) -> dict:
        """Build the context dict fed to the LLM for the next turn."""
        recent_turns = self.conversation_summary[-3:]  # last 3 turns
        return {
            "recent_conversation_summary": [
                f"Turn {t['turn']}: User asked about '{', '.join(t['summary']['user_keywords'])}' -> Response: '{t['summary']['response_summary']}'"
                for t in recent_turns
            ],
            "key_coordinates": self._extract_key_coordinates(),
            "context_size_bytes": self.current_size,
        }
# --- Usage example ---
if __name__ == '__main__':
    import re

    # Demo: LayerResetManager
    class MockLLM:
        def clear_kv_cache(self, layer):
            pass  # no-op stub

    print("--- LayerResetManager Demo ---")
    manager = LayerResetManager(MockLLM())
    manager.reset_layer24_for_new_turn()

    # Demo: ExternalState
    print("\n--- ExternalState Demo ---")
    state = ExternalState(max_size_bytes=500)  # small budget for demo purposes

    # Record a few turns.
    state.add_turn_summary(1, "心筋梗塞の原因は?", "冠動脈の閉塞が主な原因です...", [[28, 55, 15]])
    state.add_turn_summary(2, "治療法は?", "カテーテル治療やバイパス手術があります...", [[28, 35, 20]])
    state.add_turn_summary(3, "予防について", "食生活の改善、運動、禁煙が重要です...", [[28, 35, 85]])

    print(f"\nCurrent state size: {state.current_size} bytes")
    print("Context for next turn:", json.dumps(state.get_context_for_next_turn(), indent=2, ensure_ascii=False))

    # Add extra turns to exercise the compression (eviction) path.
    print("\nAdding more turns to trigger compression...")
    for turn_no in range(4, 10):
        state.add_turn_summary(turn_no, f"質問{turn_no}", f"回答{turn_no}...", [[turn_no, turn_no, turn_no]])

    print(f"Final state size after compression: {state.current_size} bytes")
    print("Final context:", json.dumps(state.get_context_for_next_turn(), indent=2, ensure_ascii=False))