pragunk committed on
Commit
dfea88e
·
verified ·
1 Parent(s): 0fdc20b

Create adaptive_cache/env.py

Browse files
Files changed (1) hide show
  1. adaptive_cache/env.py +81 -0
adaptive_cache/env.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Dict, Any, Tuple
3
+ from .simulator import CacheSimulator
4
+ from .workloads import generate_easy_task, generate_medium_task, generate_hard_task
5
+
6
# What the agent sees before choosing an eviction slot.
# Documented with comments rather than a docstring so the generated
# pydantic schema (which would pick up a docstring) stays unchanged.
class Observation(BaseModel):
    # ID of the item the workload is requesting at the current step;
    # AdaptiveCacheEnv.state() reports -1 once the workload is exhausted.
    incoming_request: int = Field(
        description="The ID of the data item being requested.",
    )
    # One entry per cache slot.
    cache_state: List[int] = Field(
        description="Current items in the cache. -1 means empty.",
    )
    # One entry per cache slot, aligned with cache_state.
    idle_times: List[int] = Field(
        description="Time steps since each cache slot was last accessed.",
    )
10
+
11
# The agent's decision for a single step: which cache slot to overwrite
# with the incoming request.
class Action(BaseModel):
    evict_index: int = Field(
        description="The index (0 to capacity-1) of the cache slot to evict.",
    )
13
+
14
class AdaptiveCacheEnv:
    """Episodic cache-eviction environment driven by a pre-generated workload.

    At every step the agent is shown the pending request together with the
    cache contents and per-slot idle times, and answers with the slot to
    overwrite. The workload is generated once in ``__init__`` and replayed
    unchanged after every ``reset()``.
    """

    def __init__(self, task_level: str = "easy", capacity: int = 10):
        """Build the simulator and pick the workload for ``task_level``.

        Args:
            task_level: ``"easy"`` or ``"medium"``; any other value selects
                the hard workload.
            capacity: Number of cache slots passed to ``CacheSimulator``
                (and to the medium workload generator as ``cache_size``).
        """
        self.capacity = capacity
        self.task_level = task_level
        self.sim = CacheSimulator(capacity)

        if task_level == "easy":
            self.workload = generate_easy_task()
        elif task_level == "medium":
            self.workload = generate_medium_task(cache_size=capacity)
        else:
            # Every unrecognised level falls through to the hard task.
            self.workload = generate_hard_task()

        self.step_count = 0
        self.hits = 0

    def reset(self) -> Observation:
        """Start a new episode on the same workload and return its first observation."""
        self.sim = CacheSimulator(self.capacity)
        self.step_count = 0
        self.hits = 0
        return self.state()

    def state(self) -> Observation:
        """Return the observation for the current step without advancing it.

        Once the workload is exhausted (or if it is empty) the incoming
        request is reported as ``-1``.
        """
        if self.step_count >= len(self.workload):
            current_item = -1  # Simulation is over, no more incoming requests
        else:
            current_item = self.workload[self.step_count]

        now = self.sim.current_time
        # A last-access time of 0 (or less) is reported as idle time 0.
        # int() keeps plain Python ints in the model even if the simulator
        # stores NumPy scalars (presumably it does, given cache.tolist()).
        idle_times = [
            int(now - t) if t > 0 else 0
            for t in self.sim.last_access_time
        ]
        return Observation(
            incoming_request=current_item,
            cache_state=self.sim.cache.tolist(),
            idle_times=idle_times,
        )

    def _info(self) -> Dict[str, Any]:
        """Return the running hit-rate summary attached to every transition."""
        return {
            "score": self.hits / max(1, self.step_count),
            "hits": self.hits,
            "steps": self.step_count,
        }

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply ``action`` to the pending request and advance one step.

        Args:
            action: Slot to overwrite with the current incoming request.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is ``1.0`` when
            ``CacheSimulator.request_item`` reports the *next* request as a
            hit, ``-1.0`` when it does not, and ``0.0`` on the terminal
            transition. ``info`` always carries ``"score"``, ``"hits"`` and
            ``"steps"``.

        Calling ``step`` after the episode has ended (or on an empty
        workload) does not touch the simulator; it returns the terminal
        transition again instead of raising ``IndexError``.
        """
        total_requests = len(self.workload)

        # Guard mirrors the one in state(): nothing left to serve.
        if self.step_count >= total_requests:
            return self.state(), 0.0, True, self._info()

        # 1. Apply Action (Evict and Insert)
        current_item = self.workload[self.step_count]
        self.sim.evict_and_insert(action.evict_index, current_item)

        # 2. Advance time strictly by 1 step
        self.step_count += 1

        # 3. Check Episode Boundary
        if self.step_count >= total_requests:
            # step_count == len(workload) here, so the running score in
            # _info() equals the final score over the whole workload.
            return self.state(), 0.0, True, self._info()

        # 4. Evaluate the *next* request strictly without fast-forwarding.
        next_item = self.workload[self.step_count]
        if self.sim.request_item(next_item):
            # The agent will see this hit in the next observation and can
            # choose an eviction slot that doesn't hurt.
            reward = 1.0
            self.hits += 1
        else:
            reward = -1.0

        return self.state(), reward, False, self._info()