Spaces:
Sleeping
Sleeping
Create adaptive_cache/env.py
Browse files- adaptive_cache/env.py +81 -0
adaptive_cache/env.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Dict, Any, Tuple
|
| 3 |
+
from .simulator import CacheSimulator
|
| 4 |
+
from .workloads import generate_easy_task, generate_medium_task, generate_hard_task
|
| 5 |
+
|
| 6 |
+
class Observation(BaseModel):
    # What the agent sees before choosing an eviction slot.
    # NOTE: documented with comments rather than a class docstring, since a
    # docstring on a pydantic model is emitted as the schema description.

    # Identifier of the item currently being requested.
    incoming_request: int = Field(
        description="The ID of the data item being requested.",
    )
    # Contents of every cache slot, in slot order.
    cache_state: List[int] = Field(
        description="Current items in the cache. -1 means empty.",
    )
    # Per-slot idle duration, aligned index-for-index with the cache slots.
    idle_times: List[int] = Field(
        description="Time steps since each cache slot was last accessed.",
    )
|
| 10 |
+
|
| 11 |
+
class Action(BaseModel):
    # The agent's decision for a single step: which cache slot to overwrite.
    # NOTE: documented with comments rather than a class docstring, since a
    # docstring on a pydantic model is emitted as the schema description.
    evict_index: int = Field(
        description="The index (0 to capacity-1) of the cache slot to evict.",
    )
|
| 13 |
+
|
| 14 |
+
class AdaptiveCacheEnv:
    """Episode-based environment where an agent chooses which cache slot to evict.

    Each episode replays a fixed request sequence (``self.workload``) against a
    ``CacheSimulator``. On every ``step`` the current request is written into
    the slot chosen by the agent, then the following request is looked up in
    the cache to produce the reward.

    Attributes:
        capacity: Number of cache slots passed to the simulator.
        task_level: Difficulty label used to pick the workload generator.
        sim: The underlying ``CacheSimulator`` instance (recreated on reset).
        workload: Request sequence generated once in ``__init__``; ``reset``
            reuses it rather than generating a new one.
        step_count: Index of the request the agent must handle next.
        hits: Number of cache hits recorded so far in the current episode.
    """

    def __init__(self, task_level: str = "easy", capacity: int = 10):
        """Build the simulator and generate the workload for ``task_level``.

        Args:
            task_level: ``"easy"`` or ``"medium"`` select the matching
                generator; any other value falls back to the hard workload.
            capacity: Number of cache slots. Also forwarded to the medium
                workload generator as ``cache_size``.
        """
        self.capacity = capacity
        self.task_level = task_level
        self.sim = CacheSimulator(capacity)

        if task_level == "easy":
            self.workload = generate_easy_task()
        elif task_level == "medium":
            self.workload = generate_medium_task(cache_size=capacity)
        else:
            # Deliberately permissive: unrecognised levels get the hard task.
            self.workload = generate_hard_task()

        self.step_count = 0
        self.hits = 0

    def reset(self) -> Observation:
        """Start a new episode on the same workload and return its first observation."""
        self.sim = CacheSimulator(self.capacity)
        self.step_count = 0
        self.hits = 0
        return self.state()

    def state(self) -> Observation:
        """Return the current observation without advancing the episode.

        Once every request has been consumed, ``incoming_request`` is ``-1``.
        """
        if self.step_count >= len(self.workload):
            # Episode is over: there is no further request to show.
            current_item = -1
        else:
            current_item = self.workload[self.step_count]

        now = self.sim.current_time
        # Slots whose last-access time is not positive report an idle time of 0
        # (presumably slots that were never accessed; confirm against
        # CacheSimulator).
        idle_times = [
            (now - t) if t > 0 else 0 for t in self.sim.last_access_time
        ]
        return Observation(
            incoming_request=current_item,
            cache_state=self.sim.cache.tolist(),
            idle_times=idle_times,
        )

    def _terminal_transition(self) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Build the end-of-episode result; the score is hits over workload length."""
        final_score = self.hits / max(1, len(self.workload))
        info = {"score": final_score, "hits": self.hits, "steps": self.step_count}
        return self.state(), 0.0, True, info

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one eviction decision and advance the episode by one request.

        Args:
            action: Slot index that receives the current request.

        Returns:
            A tuple ``(observation, reward, done, info)``. ``reward`` is
            ``1.0`` when the next request is a cache hit, ``-1.0`` when it is a
            miss, and ``0.0`` on the terminal transition. ``info`` always
            contains ``"score"``, ``"hits"`` and ``"steps"``; the score is
            hits over steps taken while running and hits over workload length
            once the episode is done.
        """
        # Stepping a finished episode (or an empty workload) used to raise a
        # bare IndexError; mirror the guard in state() and report the terminal
        # transition again instead.
        if self.step_count >= len(self.workload):
            return self._terminal_transition()

        # 1. Apply the action: the current request replaces the chosen slot.
        current_item = self.workload[self.step_count]
        self.sim.evict_and_insert(action.evict_index, current_item)

        # 2. Advance exactly one request.
        self.step_count += 1

        # 3. Episode boundary: nothing left to look up, so no reward is given.
        done = self.step_count >= len(self.workload)
        if done:
            return self._terminal_transition()

        # 4. Look up the *next* request; the agent sees it in the returned
        #    observation and decides on its slot in the following step.
        next_item = self.workload[self.step_count]
        is_hit = self.sim.request_item(next_item)

        if is_hit:
            reward = 1.0
            self.hits += 1
        else:
            reward = -1.0

        current_score = self.hits / max(1, self.step_count)
        info = {"score": current_score, "hits": self.hits, "steps": self.step_count}

        return self.state(), reward, done, info
|