kofdai committed on
Commit
ad74a4b
·
verified ·
1 Parent(s): 5a37439

Upload layer5_state_management.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. layer5_state_management.py +131 -0
layer5_state_management.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import re
from collections import Counter
from datetime import datetime
4
+
5
class LayerResetManager:
    """Manages resets of Layer 24 (the LLM's final layer) — a conceptual implementation."""

    def __init__(self, llm_model_mock):
        # The (possibly mock) LLM whose cache we manage, plus an audit log of resets.
        self.llm = llm_model_mock
        self.reset_history = []

    def reset_layer24_for_new_turn(self):
        """Simulate clearing Layer 24's KV cache at the start of a new turn.

        Appends an ISO-timestamped entry to ``reset_history`` for every reset.
        """
        # Invoke the model's real `clear_kv_cache` hook when one is available.
        clear_hook = getattr(self.llm, 'clear_kv_cache', None)
        if callable(clear_hook):
            clear_hook(layer=24)

        print("Simulating: KV cache for Layer 24 has been reset.")
        self.reset_history.append({
            "timestamp": datetime.now().isoformat(),
            "action": "layer24_reset",
        })
26
+
27
class ExternalState:
    """
    Manages a session's external state, bounding memory use to roughly
    ``max_size_bytes`` (default ~10 KB) of serialized turn summaries.
    """

    def __init__(self, max_size_bytes=10240):
        # Compressed per-turn summaries, oldest first.
        self.conversation_summary = []
        # Flat trail of every DB coordinate used across all turns.
        self.coordinate_trail = []
        self.max_size_bytes = max_size_bytes
        # Running total of the serialized (UTF-8 JSON) size of stored summaries.
        self.current_size = 0

    def _extract_keywords(self, text: str, max_words=3) -> list:
        """Return up to *max_words* of the most frequent words in *text*.

        NOTE: requires a module-level ``import re`` (the original file only
        imported ``re`` inside the ``__main__`` guard, which broke imports).
        """
        words = re.findall(r'\b\w+\b', text.lower())
        # Minimal stop-word filtering (common Japanese particles / copulas).
        stopwords = {"です", "ます", "が", "は", "を", "に", "と", "の"}
        words = [word for word in words if word not in stopwords]
        return [word for word, freq in Counter(words).most_common(max_words)]

    def _compress_turn(self, user_input: str, llm_response: str, db_coords: list) -> dict:
        """Summarize/compress one turn into a small dict."""
        return {
            "user_keywords": self._extract_keywords(user_input),
            # Truncate long responses to 100 chars plus an ellipsis marker.
            "response_summary": llm_response[:100] + "..." if len(llm_response) > 100 else llm_response,
            "db_coords_used": db_coords[:3],
            "turn_length": len(llm_response)
        }

    def _compress_old_turns(self):
        """Drop the oldest turns until the state fits within ``max_size_bytes``.

        BUG FIX: the original only ran the size-trimming loop once more than
        20 turns had accumulated, so small ``max_size_bytes`` values were never
        enforced for shorter sessions. Trimming now also triggers whenever the
        byte cap itself is exceeded, matching the class's stated contract.
        """
        if self.current_size > self.max_size_bytes or len(self.conversation_summary) > 20:
            print("State size limit exceeded, compressing old turns...")
            # Evict from the oldest turn onward until we are back under the cap.
            while self.current_size > self.max_size_bytes and self.conversation_summary:
                removed_turn = self.conversation_summary.pop(0)
                self.current_size -= removed_turn.get("size_bytes", 0)

    def add_turn_summary(self, turn_num: int, user_input: str, llm_response: str, db_coords: list):
        """Compress one turn and append it to the external state, then enforce the size cap."""
        summary = self._compress_turn(user_input, llm_response, db_coords)
        summary_size = len(json.dumps(summary, ensure_ascii=False).encode('utf-8'))

        turn_data = {
            "turn": turn_num,
            "summary": summary,
            "size_bytes": summary_size
        }

        self.conversation_summary.append(turn_data)
        self.current_size += summary_size
        if db_coords:
            self.coordinate_trail.extend(db_coords)

        self._compress_old_turns()

    def _extract_key_coordinates(self) -> list:
        """Return the (up to 5) most frequently used coordinates from the trail."""
        if not self.coordinate_trail:
            return []
        coord_freq = Counter([tuple(c) for c in self.coordinate_trail])
        return [list(coord) for coord, count in coord_freq.most_common(5)]

    def get_context_for_next_turn(self) -> dict:
        """Build a compact context dict for the next LLM inference turn."""
        recent_turns = self.conversation_summary[-3:]  # last 3 turns only
        return {
            "recent_conversation_summary": [
                f"Turn {t['turn']}: User asked about '{', '.join(t['summary']['user_keywords'])}' -> Response: '{t['summary']['response_summary']}'"
                for t in recent_turns
            ],
            "key_coordinates": self._extract_key_coordinates(),
            "context_size_bytes": self.current_size
        }
99
+
100
# --- Usage example ---
if __name__ == '__main__':
    import re

    # Demo 1: LayerResetManager driven by a stand-in model object.
    class MockLLM:
        def clear_kv_cache(self, layer):
            pass  # no-op stub

    print("--- LayerResetManager Demo ---")
    manager = LayerResetManager(MockLLM())
    manager.reset_layer24_for_new_turn()

    # Demo 2: ExternalState with a deliberately tiny size cap.
    print("\n--- ExternalState Demo ---")
    state = ExternalState(max_size_bytes=500)

    demo_turns = [
        (1, "心筋梗塞の原因は?", "冠動脈の閉塞が主な原因です...", [[28, 55, 15]]),
        (2, "治療法は?", "カテーテル治療やバイパス手術があります...", [[28, 35, 20]]),
        (3, "予防について", "食生活の改善、運動、禁煙が重要です...", [[28, 35, 85]]),
    ]
    for num, question, answer, coords in demo_turns:
        state.add_turn_summary(num, question, answer, coords)

    print(f"\nCurrent state size: {state.current_size} bytes")
    print("Context for next turn:", json.dumps(state.get_context_for_next_turn(), indent=2, ensure_ascii=False))

    # Push in additional turns so the compression (eviction) path fires.
    print("\nAdding more turns to trigger compression...")
    for turn in range(4, 10):
        state.add_turn_summary(turn, f"質問{turn}", f"回答{turn}...", [[turn, turn, turn]])

    print(f"\nFinal state size after compression: {state.current_size} bytes")
    print("Final context:", json.dumps(state.get_context_for_next_turn(), indent=2, ensure_ascii=False))