prithic07 commited on
Commit
6cad4bb
·
1 Parent(s): 222f8ce

Meta x Scaler Compliance: Strict logging, Port 8000 sync, and mandatory env vars.

Browse files
Dockerfile CHANGED
@@ -7,23 +7,8 @@ ENV PYTHONPATH=/app
7
  WORKDIR /app
8
 
9
  # Install system dependencies
10
- RUN apt-get update && apt-get install -y \
11
- build-essential \
12
- curl \
13
- && rm -rf /var/lib/apt/lists/*
14
-
15
- # Optimize for 2 vCPU and 8GB RAM
16
- # Copy and install Python dependencies separately for layer caching
17
  COPY requirements.txt .
18
  RUN pip install --no-cache-dir -r requirements.txt
19
 
20
- # Copy all source files
21
- COPY context_pruning_env ./context_pruning_env
22
- COPY inference.py .
23
- COPY openenv.yaml .
24
-
25
- # Expose the standard OpenEnv port
26
- EXPOSE 7860
27
-
28
- # FastAPI app entrypoint
29
- CMD ["uvicorn", "context_pruning_env.server.app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "2"]
 
7
  WORKDIR /app
8
 
9
  # Install system dependencies
10
+ # Install dependencies
 
 
 
 
 
 
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
+ CMD ["uvicorn", "context_pruning_env.server.app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,62 +1,80 @@
1
- # Adaptive Context Optimization Agent (ContextPrune)
2
 
3
- **ContextPrune** is a specialized Reinforcement Learning (RL) environment for optimizing retrieved context in RAG systems, designed for the **Meta x Scaler Hackathon**.
4
-
5
- > “ContextPrune reduces noise and tokens in RAG pipelines while preserving answer quality.”
6
 
7
  ---
8
 
9
- ## 🚀 Hackathon Tasks
 
10
 
11
- | Task Name | Difficulty | Objective |
12
- | :--- | :--- | :--- |
13
- | `noise_filter` | **Easy** | Identify and prune 4 random noise chunks while keeping the gold chunk. |
14
- | `deduplication` | **Medium** | Recognize duplicate gold chunks and prune exactly one of them. |
15
- | `sentence_distillation` | **Hard** | Sharp pruning of context to isolate the core sentence containing the answer. |
16
 
17
- ## 🛠️ Installation & Setup
18
 
19
- ```bash
20
- # Install dependencies
21
- pip install -r requirements.txt
22
 
23
- # Set your Gemini API Key
24
- export GOOGLE_API_KEY=your_key_here
 
 
25
 
26
- # Verify the environment and task logic
27
- pytest test_tasks.py
28
- ```
 
 
 
29
 
30
- ## 🧠 Environment API
31
 
32
- The environment strictly follows the **openenv-core** interface (`reset`, `step`, `state`):
33
 
34
- - **Reset**: `env.reset(task_name=...)`
35
- - **Step**: `env.step(ContextAction(mask=...))`
36
- - **Observation**: Question + 5 context chunks as strings.
37
- - **Reward**: Programmatic 0.0–1.0 score based on accuracy and efficiency.
 
38
 
39
- ## 🐳 Docker & OpenEnv
40
 
41
- The project is containerized for deployment on 2 vCPU and 8GB RAM environments:
 
 
 
 
42
 
 
 
 
 
 
43
  ```bash
44
- # Build the production container
45
- docker build -t context-prune .
46
 
47
- # Run the OpenEnv server
48
- docker run -p 7860:7860 context-prune
49
- ```
50
 
51
- ## 📜 Standardized Inference
 
 
52
 
53
- The [inference.py](file:///d:/Projects/RAG/inference.py) script emits logs with mandatory tags for automated evaluation:
 
 
 
54
 
55
- ```text
56
- <OBSERVATION>{...}</OBSERVATION>
57
- <ACTION>{...}</ACTION>
58
- <REWARD>{...}</REWARD>
59
  ```
60
 
 
 
 
 
61
  ---
62
- *Powered by OpenEnv*
 
1
+ # ContextPrune: Adaptive Context Optimization Environment
2
 
3
+ **ContextPrune** is a Meta x Scaler Hackathon compliant reinforcement learning environment designed for Phase 1: Automated Validation. It focuses on the critical task of context pruning for RAG pipelines, reducing noise and token counts while strictly preserving answer faithfulness.
 
 
4
 
5
  ---
6
 
7
+ ## 🌍 Environment Description
8
+ ContextPrune implements the **OpenEnv Spec**, providing a standardized interface for RL agents to optimize retrieved contexts. The environment presents a query and multiple context chunks (from SQuAD or synthetic noise) where the agent must decide which chunks to keep and which to prune using a binary mask.
9
 
10
+ ### Resource Constraints
11
+ - **vCPU**: 2
12
+ - **RAM**: 8GB
13
+ - **Runtime**: Python 3.10+
14
+ - **Port**: 8000 (OpenEnv Server)
15
 
16
+ ---
17
 
18
+ ## 🎮 Action & Observation Spaces
 
 
19
 
20
+ ### Action Space (ContextAction)
21
+ - **Type**: Binary Mask (`List[int]`)
22
+ - **Values**: `1` (Keep), `0` (Prune)
23
+ - **Constraint**: Must match the number of chunks in the current observation.
24
 
25
+ ### Observation Space (ContextObservation)
26
+ - **question**: The user query to be answered.
27
+ - **chunks**: A list of text strings representing the retrieved context.
28
+ - **initial_token_count**: The total token count before optimization.
29
+ - **current_token_count**: Cumulative tokens of the currently selected chunks.
30
+ - **task_name**: The identifier for the current pruning task.
31
 
32
+ ---
33
 
34
+ ## 🏆 Task Descriptions
35
 
36
+ | Task ID | Name | Difficulty | Scoring Logic |
37
+ | :--- | :--- | :--- | :--- |
38
+ | **01** | `noise_purge` | **Easy** | 0.0 or 1.0. Perfect score if all noise is deleted and the answer is kept. |
39
+ | **02** | `dedupe_arena` | **Medium** | 1.0 if word count is reduced by >50% while preserving the answer. |
40
+ | **03** | `signal_extract` | **Hard** | $1 - (FinalTokens/InitialTokens)$. Score scales with compression ratio. |
41
 
42
+ ---
43
 
44
+ ## 📈 Reward Function (Trajectory Signals)
45
+ The environment emits rewards based on the agent's efficiency and accuracy:
46
+ - **Efficiency**: `+0.1` for every irrelevant chunk or duplicate correctly pruned.
47
+ - **Accuracy**: `+0.7` bonus at the end of the trajectory if the "Gold Chunk" is preserved.
48
+ - **Death Penalty**: `-1.0` and immediate `done=True` if the agent prunes the Gold Chunk (Information Loss).
49
 
50
+ ---
51
+
52
+ ## 🛠️ Setup Instructions
53
+
54
+ ### 1. Local Development
55
  ```bash
56
+ # Install dependencies
57
+ pip install -r requirements.txt
58
 
59
+ # Configure API (Optional for testing)
60
+ echo "GOOGLE_API_KEY=your_key" > .env
 
61
 
62
+ # Run Inference Evaluation
63
+ python inference.py
64
+ ```
65
 
66
+ ### 2. Docker Deployment
67
+ ```bash
68
+ # Build the standardized image
69
+ docker build -t contextprune .
70
 
71
+ # Start the environment server
72
+ docker run -p 8000:8000 contextprune
 
 
73
  ```
74
 
75
+ ### 3. Inference Logging
76
+ Mandatory logs are emitted in the following format for the Hackathon Evaluator:
77
+ `task=<name> env=contextprune model=<model> step=<n> action=<str> reward=<0.00> done=<bool> score=<score> rewards=<r1,r2...>`
78
+
79
  ---
80
+ *Built for the Meta x Scaler Hackathon 2026*
app_ui.py CHANGED
@@ -30,16 +30,20 @@ async def call_gemini(prompt: str, model_name: str = "gemini-1.5-flash") -> str:
30
  except Exception as e:
31
  return f"ERROR: {str(e)}"
32
 
33
- def chunk_text(text: str, max_chunks: int = 5) -> List[str]:
34
  """Split text into manageable chunks (paragraphs or sentences)."""
35
- # Split by double newline first
36
- chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
37
- if len(chunks) < 2:
38
- # Split by sentence if only one paragraph
39
- chunks = [c.strip() for c in re.split(r'(?<=[.!?])\s+', text) if c.strip()]
40
 
41
- # Simple limit to 5-10 chunks for the demo
42
- return chunks[:10]
 
 
 
 
 
 
 
43
 
44
  async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
45
  """
@@ -53,26 +57,49 @@ async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
53
  # Prompt for selection
54
  selection_prompt = (
55
  f"Query: {query}\n\n"
56
- "Below are several context chunks. Identify which are RELEVANT and which are NOISE or DUPLICATES. "
57
- "Output a JSON list of indices (0-indexed) of the chunks to KEEP.\n"
58
- "Example output: [0, 2, 3]\n\n"
59
  "Chunks:\n"
60
  )
61
  for i, c in enumerate(chunks):
62
  selection_prompt += f"Chunk {i}: {c}\n\n"
63
 
64
  raw_response = await call_gemini(selection_prompt)
 
65
 
66
- # Extract indices
67
- match = re.search(r"\[([\d\s,]+)\]", raw_response)
68
- if match:
69
- try:
70
- indices = json.loads(f"[{match.group(1)}]")
71
- kept_chunks = [chunks[i] for i in indices if i < len(chunks)]
72
- except:
73
- kept_chunks = chunks # Fallback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  else:
75
- kept_chunks = chunks # Fallback
 
 
76
 
77
  optimized_text = " ".join(kept_chunks)
78
 
 
30
  except Exception as e:
31
  return f"ERROR: {str(e)}"
32
 
33
+ def chunk_text(text: str, max_chunks: int = 10) -> List[str]:
34
  """Split text into manageable chunks (paragraphs or sentences)."""
35
+ # 1. First split by double newlines (paragraphs)
36
+ initial_chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
 
 
 
37
 
38
+ final_chunks = []
39
+ # 2. If paragraphs are too few or long, split them into sentences
40
+ for chunk in initial_chunks:
41
+ # Split by sentence markers [.!?] followed by space or newline
42
+ sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+|\n', chunk) if s.strip()]
43
+ final_chunks.extend(sentences)
44
+
45
+ # Simple limit to 10 chunks to avoid overwhelming the prompt
46
+ return final_chunks[:max_chunks]
47
 
48
  async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
49
  """
 
57
  # Prompt for selection
58
  selection_prompt = (
59
  f"Query: {query}\n\n"
60
+ "TASK: Select indices of context chunks that are directly relevant to the query. "
61
+ "Remove noise, random facts, and duplicates. "
62
+ "OUTPUT: Output ONLY the list of indices as a JSON array like [0, 2, 4]. No explanations.\n\n"
63
  "Chunks:\n"
64
  )
65
  for i, c in enumerate(chunks):
66
  selection_prompt += f"Chunk {i}: {c}\n\n"
67
 
68
  raw_response = await call_gemini(selection_prompt)
69
+ print(f"DEBUG: Gemini Response: {raw_response}")
70
 
71
+ from context_pruning_env.graders import (
72
+ grade_noise_purge,
73
+ grade_dedupe_arena,
74
+ grade_signal_extract
75
+ )
76
+
77
+ # Ultra-robust extraction
78
+ indices = []
79
+ try:
80
+ match = re.search(r"\[([\d\s,]+)\]", raw_response)
81
+ if match:
82
+ # Found a bracketed list of numbers
83
+ content = match.group(0) # e.g. "[0, 1, 2]"
84
+ indices = json.loads(content)
85
+ else:
86
+ # Try finding any numbers in the response if no brackets
87
+ nums = re.findall(r"\d+", raw_response)
88
+ indices = [int(n) for n in nums]
89
+
90
+ # Clean up: only valid unique indices
91
+ indices = list(set([int(i) for i in indices if isinstance(i, int) and 0 <= i < len(chunks)]))
92
+ print(f"DEBUG: Successfully extracted indices: {indices}")
93
+ except Exception as e:
94
+ print(f"DEBUG: Extraction Error: {e}")
95
+ indices = []
96
+
97
+ if indices:
98
+ kept_chunks = [chunks[i] for i in sorted(indices)]
99
  else:
100
+ # Fallback to keep everything if AI fails, but message it
101
+ print("DEBUG: Pruning failed, keeping original context.")
102
+ kept_chunks = chunks
103
 
104
  optimized_text = " ".join(kept_chunks)
105
 
context_pruning_env/env.py CHANGED
@@ -12,9 +12,9 @@ from context_pruning_env.models import (
12
  )
13
  from context_pruning_env.utils import SQuADLoader, count_tokens
14
  from context_pruning_env.graders import (
15
- grade_noise_filter,
16
- grade_deduplication,
17
- grade_sentence_distillation
18
  )
19
 
20
  class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningState]):
@@ -31,13 +31,14 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
31
  self,
32
  seed: Optional[int] = None,
33
  episode_id: Optional[str] = None,
34
- task_name: Optional[str] = "noise_filter",
35
  **kwargs: Any,
36
  ) -> ContextObservation:
37
  """
38
  Starts a new episode with the specified task.
39
  """
40
- question, chunks_data = self.loader.get_episode(task_name or "noise_filter")
 
41
 
42
  chunks = []
43
  total_tokens = 0
@@ -53,7 +54,7 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
53
 
54
  self._state = PruningState(
55
  episode_id=episode_id or str(uuid4()),
56
- task_name=task_name or "noise_filter",
57
  question=question,
58
  chunks=chunks,
59
  initial_tokens=total_tokens,
@@ -69,7 +70,8 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
69
  done=self._state.done,
70
  question=self._state.question,
71
  chunks=[c.content for c in self._state.chunks],
72
- token_count=sum(c.tokens for c in self._state.chunks),
 
73
  task_name=self._state.task_name,
74
  message=message
75
  )
@@ -80,34 +82,67 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
80
  **kwargs: Any,
81
  ) -> ContextObservation:
82
  """
83
- Takes a binary mask and grades according to task rules.
84
  """
85
  if self._state.done:
86
  return self._observe(message="Episode is already done.")
87
 
88
  mask = action.mask
 
 
 
89
 
90
- # 1. Select Grader
91
- if self._state.task_name == "noise_filter":
92
- reward_obj = grade_noise_filter(mask, self._state.chunks)
93
- elif self._state.task_name == "deduplication":
94
- reward_obj = grade_deduplication(mask, self._state.chunks)
95
- elif self._state.task_name == "sentence_distillation":
96
- reward_obj = grade_sentence_distillation(mask, self._state.chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  else:
98
- reward_obj = grade_noise_filter(mask, self._state.chunks)
99
 
100
  self._state.done = True
101
  self._state.step_count += 1
102
 
103
- # Signal reward in observation
104
- obs = self._observe(message=reward_obj.message)
105
- obs.reward = reward_obj.score
106
 
107
- # Add detailed reward to metadata for hackathon transparency
108
  if not obs.metadata:
109
  obs.metadata = {}
110
- obs.metadata["reward_detail"] = reward_obj.model_dump()
 
 
 
 
 
111
 
112
  return obs
113
 
 
12
  )
13
  from context_pruning_env.utils import SQuADLoader, count_tokens
14
  from context_pruning_env.graders import (
15
+ grade_noise_purge,
16
+ grade_dedupe_arena,
17
+ grade_signal_extract
18
  )
19
 
20
  class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningState]):
 
31
  self,
32
  seed: Optional[int] = None,
33
  episode_id: Optional[str] = None,
34
+ task_name: Optional[str] = "noise_purge",
35
  **kwargs: Any,
36
  ) -> ContextObservation:
37
  """
38
  Starts a new episode with the specified task.
39
  """
40
+ task_name = task_name or "noise_purge"
41
+ question, chunks_data = self.loader.get_episode(task_name)
42
 
43
  chunks = []
44
  total_tokens = 0
 
54
 
55
  self._state = PruningState(
56
  episode_id=episode_id or str(uuid4()),
57
+ task_name=task_name,
58
  question=question,
59
  chunks=chunks,
60
  initial_tokens=total_tokens,
 
70
  done=self._state.done,
71
  question=self._state.question,
72
  chunks=[c.content for c in self._state.chunks],
73
+ initial_token_count=self._state.initial_tokens,
74
+ current_token_count=sum(c.tokens for c in self._state.chunks),
75
  task_name=self._state.task_name,
76
  message=message
77
  )
 
82
  **kwargs: Any,
83
  ) -> ContextObservation:
84
  """
85
+ Takes a binary mask and calculates rewards based on trajectory signals.
86
  """
87
  if self._state.done:
88
  return self._observe(message="Episode is already done.")
89
 
90
  mask = action.mask
91
+ if len(mask) != len(self._state.chunks):
92
+ # Pad or truncate mask to match chunk count if agent is misaligned
93
+ mask = (mask + [1] * len(self._state.chunks))[:len(self._state.chunks)]
94
 
95
+ # Trajectory Simulation Logic
96
+ total_reward = 0.0
97
+ efficiency_reward = 0.0
98
+ accuracy_reward = 0.0
99
+ gold_penalty = 0.0
100
+ success = True
101
+
102
+ for i, kept in enumerate(mask):
103
+ chunk = self._state.chunks[i]
104
+ if not kept: # Pruned
105
+ if chunk.is_gold:
106
+ # Critical Failure
107
+ gold_penalty = -1.0
108
+ success = False
109
+ break # Immediate stop
110
+ else:
111
+ # Correctly pruned noise/duplicate
112
+ efficiency_reward += 0.1
113
+ else: # Kept
114
+ pass
115
+
116
+ # Final Accuracy Bonus
117
+ if success:
118
+ accuracy_reward = 0.7
119
+
120
+ total_reward = efficiency_reward + accuracy_reward + gold_penalty
121
+
122
+ # Task Score (Normalized 0.0 to 1.0 for the evaluator)
123
+ if self._state.task_name == "noise_purge":
124
+ score_obj = grade_noise_purge(mask, self._state.chunks)
125
+ elif self._state.task_name == "dedupe_arena":
126
+ score_obj = grade_dedupe_arena(mask, self._state.chunks)
127
+ elif self._state.task_name == "signal_extract":
128
+ score_obj = grade_signal_extract(mask, self._state.chunks)
129
  else:
130
+ score_obj = grade_noise_purge(mask, self._state.chunks)
131
 
132
  self._state.done = True
133
  self._state.step_count += 1
134
 
135
+ obs = self._observe(message=score_obj.message)
136
+ obs.reward = total_reward # Trajectory reward
 
137
 
 
138
  if not obs.metadata:
139
  obs.metadata = {}
140
+ obs.metadata["eval_score"] = score_obj.score # Grader score
141
+ obs.metadata["reward_detail"] = {
142
+ "efficiency": efficiency_reward,
143
+ "accuracy": accuracy_reward,
144
+ "penalty": gold_penalty
145
+ }
146
 
147
  return obs
148
 
context_pruning_env/graders.py CHANGED
@@ -1,56 +1,56 @@
1
  from typing import List
2
  from context_pruning_env.models import ChunkItem, ContextReward
3
 
4
- def grade_noise_filter(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
5
  """
6
- Score: 1.0 if gold kept AND noise pruned.
7
  """
8
  gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
9
  noise_pruned = all(mask[i] == 0 for i in range(len(mask)) if not chunks[i].is_gold)
10
 
11
  if not gold_kept:
12
- return ContextReward(score=0.0, penalty=-1.0, message="Critical Failure: Gold chunk lost.")
13
 
14
  if noise_pruned:
15
- return ContextReward(score=1.0, accuracy_bonus=0.5, efficiency_bonus=0.5, message="Perfect: Gold kept and all noise pruned.")
16
  else:
17
- return ContextReward(score=0.5, accuracy_bonus=0.5, message="Partial: Gold kept but some noise remains.")
18
 
19
- def grade_deduplication(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
20
  """
21
- Score: 1.0 if EXACTLY 1 gold kept AND 0 noise kept.
22
  """
23
- gold_indices = [i for i, c in enumerate(chunks) if c.is_gold]
24
- noise_indices = [i for i, c in enumerate(chunks) if not c.is_gold]
25
 
26
- kept_gold_count = sum(1 for i in gold_indices if mask[i] == 1)
27
- kept_noise_count = sum(1 for i in noise_indices if mask[i] == 1)
28
 
29
- if kept_gold_count == 0:
30
- return ContextReward(score=0.0, message="Critical Failure: All gold chunks lost.")
31
 
32
- if kept_gold_count == 1 and kept_noise_count == 0:
33
- return ContextReward(score=1.0, message="Perfect: Exactly 1 gold kept and 0 noise.")
34
- elif kept_gold_count > 1:
35
- return ContextReward(score=0.4, message="Partial: Duplicates detected.")
36
  else:
37
- return ContextReward(score=0.5, message="Partial: Gold kept but noise remains.")
38
 
39
- def grade_sentence_distillation(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
40
  """
41
- Score: 1.0 if gold kept AND at least 3 noise sentences pruned.
42
  """
43
- gold_index = next(i for i, c in enumerate(chunks) if c.is_gold)
44
- gold_kept = (mask[gold_index] == 1)
45
 
46
- noise_indices = [i for i, c in enumerate(chunks) if not c.is_gold]
47
- pruned_noise_count = sum(1 for i in noise_indices if mask[i] == 0)
48
 
49
  if not gold_kept:
50
- return ContextReward(score=0.0, message="Critical Failure: Summary is missing the answer.")
51
-
52
- if pruned_noise_count >= 3:
53
- return ContextReward(score=1.0, message="Perfect: Sharp distillation achieved.")
54
- else:
55
- efficiency = pruned_noise_count / len(noise_indices) if noise_indices else 1.0
56
- return ContextReward(score=0.2 + 0.5 * efficiency, message="Partial: Distillation incomplete.")
 
 
 
 
1
  from typing import List
2
  from context_pruning_env.models import ChunkItem, ContextReward
3
 
4
+ def grade_noise_purge(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
5
  """
6
+ Easy Task: Score 1.0 if gold kept AND noise pruned.
7
  """
8
  gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
9
  noise_pruned = all(mask[i] == 0 for i in range(len(mask)) if not chunks[i].is_gold)
10
 
11
  if not gold_kept:
12
+ return ContextReward(score=0.0, gold_penalty=-1.0, message="Critical: Gold chunk lost.")
13
 
14
  if noise_pruned:
15
+ return ContextReward(score=1.0, message="Perfect: All noise purged.")
16
  else:
17
+ return ContextReward(score=0.5, message="Partial: Gold kept but noise remains.")
18
 
19
+ def grade_dedupe_arena(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
20
  """
21
+ Medium Task: 1.0 if word count reduced > 50% AND gold kept.
22
  """
23
+ initial_words = sum(len(c.content.split()) for c in chunks)
24
+ final_words = sum(len(chunks[i].content.split()) for i, kept in enumerate(mask) if kept)
25
 
26
+ gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
27
+ reduction = 1.0 - (final_words / initial_words) if initial_words > 0 else 1.0
28
 
29
+ if not gold_kept:
30
+ return ContextReward(score=0.0, message="Critical: Answer lost during deduplication.")
31
 
32
+ if reduction >= 0.5:
33
+ return ContextReward(score=1.0, message=f"Great: {reduction:.1%} word reduction achieved.")
 
 
34
  else:
35
+ return ContextReward(score=0.5, message=f"Partial: Only {reduction:.1%} reduction.")
36
 
37
+ def grade_signal_extract(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
38
  """
39
+ Hard Task: 1 - (FinalTokens/InitialTokens) if gold kept.
40
  """
41
+ initial_tokens = sum(c.tokens for c in chunks)
42
+ final_tokens = sum(chunks[i].tokens for i, kept in enumerate(mask) if kept)
43
 
44
+ gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
 
45
 
46
  if not gold_kept:
47
+ return ContextReward(score=0.0, message="Critical: Signal lost in noise.")
48
+
49
+ reduction_score = 1.0 - (final_tokens / initial_tokens) if initial_tokens > 0 else 0.0
50
+ # Ensure score is at least positive if gold is kept
51
+ final_score = max(0.1, reduction_score)
52
+
53
+ return ContextReward(
54
+ score=final_score,
55
+ message=f"Signal Extracted: {reduction_score:.1%} compression."
56
+ )
context_pruning_env/models.py CHANGED
@@ -5,13 +5,12 @@ from openenv.core.env_server.types import Action, Observation, State
5
 
6
  class ContextAction(Action):
7
  """
8
- Action space: A binary mask of 5 values (1 = keep, 0 = prune).
9
  """
10
  mask: List[int] = Field(
11
  ...,
12
- min_items=5,
13
- max_items=5,
14
- description="Binary mask of 5 integers (0 or 1) indicating which chunks to keep."
15
  )
16
 
17
  class ContextObservation(Observation):
@@ -19,31 +18,32 @@ class ContextObservation(Observation):
19
  Observation provided to the agent.
20
  """
21
  question: str
22
- chunks: List[str] = Field(default_factory=list, description="List of 5 context chunks.")
23
- token_count: int = 0
 
24
  task_name: str = ""
25
  message: str = ""
26
 
27
  class ContextReward(BaseModel):
28
  """
29
- Detailed reward breakdown.
30
  """
31
  score: float = Field(0.0, ge=0.0, le=1.0, description="Overall task score (0 to 1).")
32
- accuracy_bonus: float = 0.0
33
- efficiency_bonus: float = 0.0
34
- penalty: float = 0.0
35
  message: str = ""
36
 
37
  class ChunkItem(BaseModel):
38
  """Internal representation of a context chunk."""
39
  content: str
40
  is_gold: bool = False
41
- is_duplicate: bool = False
42
  tokens: int = 0
 
43
 
44
  class PruningState(State):
45
  """
46
- Internal state of the environment.
47
  """
48
  task_name: str
49
  question: str
 
5
 
6
  class ContextAction(Action):
7
  """
8
+ Action space: A binary mask of N values (1 = keep, 0 = prune).
9
  """
10
  mask: List[int] = Field(
11
  ...,
12
+ min_items=1,
13
+ description="Binary mask of integers (0 or 1) indicating which chunks to keep."
 
14
  )
15
 
16
  class ContextObservation(Observation):
 
18
  Observation provided to the agent.
19
  """
20
  question: str
21
+ chunks: List[str] = Field(default_factory=list, description="Current context chunks.")
22
+ initial_token_count: int = 0
23
+ current_token_count: int = 0
24
  task_name: str = ""
25
  message: str = ""
26
 
27
  class ContextReward(BaseModel):
28
  """
29
+ Detailed reward breakdown for Meta x Scaler audit.
30
  """
31
  score: float = Field(0.0, ge=0.0, le=1.0, description="Overall task score (0 to 1).")
32
+ efficiency_reward: float = 0.0
33
+ accuracy_reward: float = 0.0
34
+ gold_penalty: float = 0.0
35
  message: str = ""
36
 
37
  class ChunkItem(BaseModel):
38
  """Internal representation of a context chunk."""
39
  content: str
40
  is_gold: bool = False
 
41
  tokens: int = 0
42
+ is_duplicate: bool = False
43
 
44
  class PruningState(State):
45
  """
46
+ Internal state for ContextPrune.
47
  """
48
  task_name: str
49
  question: str
context_pruning_env/server/app.py CHANGED
@@ -1,18 +1,20 @@
1
  import os
2
  from openenv.core.env_server.http_server import create_fastapi_app
3
  from context_pruning_env.env import ContextPruningEnv
4
- from context_pruning_env.models import PruningAction, PruningObservation
5
 
6
  app = create_fastapi_app(
7
  ContextPruningEnv,
8
- PruningAction,
9
- PruningObservation,
10
  )
11
 
12
  def main() -> None:
13
  import uvicorn
14
- port = int(os.environ.get("PORT", "7860"))
15
  uvicorn.run(app, host="0.0.0.0", port=port)
16
 
17
  if __name__ == "__main__":
18
  main()
 
 
 
1
  import os
2
  from openenv.core.env_server.http_server import create_fastapi_app
3
  from context_pruning_env.env import ContextPruningEnv
4
+ from context_pruning_env.models import ContextAction, ContextObservation
5
 
6
  app = create_fastapi_app(
7
  ContextPruningEnv,
8
+ ContextAction,
9
+ ContextObservation,
10
  )
11
 
12
  def main() -> None:
13
  import uvicorn
14
+ port = int(os.environ.get("PORT", "8000"))
15
  uvicorn.run(app, host="0.0.0.0", port=port)
16
 
17
  if __name__ == "__main__":
18
  main()
19
+
20
+
context_pruning_env/utils.py CHANGED
@@ -35,57 +35,44 @@ class SQuADLoader:
35
 
36
  chunks = []
37
 
38
- if task_name == "noise_filter":
39
- # 1 Gold, 4 Noise from random entries
40
  chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
41
- for _ in range(4):
42
- _, noise_entry = self._get_next_entry()
43
- chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
44
 
45
- elif task_name == "deduplication":
46
- # 2 Gold (Identical), 3 Noise
47
  chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
48
- chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": True})
49
- for _ in range(3):
50
- _, noise_entry = self._get_next_entry()
51
- chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
52
 
53
- elif task_name == "sentence_distillation":
54
- # Split gold context into sentences. Take the one with the answer.
55
- # Fill remaining slots with other sentences from the same context.
56
- sentences = re.split(r'(?<=[.!?])\s+', gold_context)
57
- answer_text = entry["answers"]["text"][0]
58
-
59
- gold_sentence = None
60
- other_sentences = []
61
- for s in sentences:
62
- if answer_text.lower() in s.lower() and gold_sentence is None:
63
- gold_sentence = s
64
- else:
65
- other_sentences.append(s)
66
-
67
- if gold_sentence is None:
68
- # Fallback if answer spans multiple sentences or is not found cleanly
69
- gold_sentence = sentences[0]
70
- other_sentences = sentences[1:]
71
-
72
- chunks.append({"content": gold_sentence, "is_gold": True, "is_duplicate": False})
73
 
74
- # Sample 4 more or fill with random if not enough sentences
75
- random.shuffle(other_sentences)
76
- for i in range(4):
77
- if i < len(other_sentences):
78
- chunks.append({"content": other_sentences[i], "is_gold": False, "is_duplicate": False})
79
- else:
80
- _, noise_entry = self._get_next_entry()
81
- chunks.append({"content": noise_entry["context"][:100], "is_gold": False, "is_duplicate": False})
82
 
83
  else:
84
- # Default to noise_filter
85
- return self.get_episode("noise_filter")
86
 
87
- # Shuffle all tasks
88
- random.shuffle(chunks)
 
 
89
  return question, chunks
90
 
91
  def count_tokens(text: str) -> int:
 
35
 
36
  chunks = []
37
 
38
+ if task_name == "noise_purge":
39
+ # Easy: 1 Gold + 1 Irrelevant
40
  chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
41
+ _, noise_entry = self._get_next_entry()
42
+ chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
 
43
 
44
+ elif task_name == "dedupe_arena":
45
+ # Medium: 1 Gold + 2 Near-Duplicates (Simulated by repeating gold)
46
  chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
47
+ # Duplicate 1: slightly modified or identical
48
+ chunks.append({"content": gold_context + " ", "is_gold": True, "is_duplicate": True})
49
+ # Duplicate 2: slightly modified
50
+ chunks.append({"content": "Actually, " + gold_context, "is_gold": True, "is_duplicate": True})
51
 
52
+ elif task_name == "signal_extract":
53
+ # Hard: 1 Long context (2,000+ words)
54
+ # We simulate this by taking 10 random SQuAD contexts and joining them.
55
+ # Only one contains the answer.
56
+ long_context_parts = []
57
+ long_context_parts.append(gold_context)
58
+ for _ in range(15): # ~15 chunks of ~150 words = ~2250 words
59
+ _, noise_entry = self._get_next_entry()
60
+ long_context_parts.append(noise_entry["context"])
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Shuffling the parts so the gold one isn't first
63
+ random.shuffle(long_context_parts)
64
+ for part in long_context_parts:
65
+ is_gold = (part == gold_context)
66
+ chunks.append({"content": part, "is_gold": is_gold, "is_duplicate": False})
 
 
 
67
 
68
  else:
69
+ # Default to noise_purge
70
+ return self.get_episode("noise_purge")
71
 
72
+ # Shuffle chunks for non-signal tasks
73
+ if task_name != "signal_extract":
74
+ random.shuffle(chunks)
75
+
76
  return question, chunks
77
 
78
  def count_tokens(text: str) -> int:
inference.py CHANGED
@@ -2,78 +2,91 @@ import os
2
  import json
3
  import logging
4
  import re
5
- import google.generativeai as genai
 
6
  from dotenv import load_dotenv
7
  from context_pruning_env.env import ContextPruningEnv
8
-
9
- # Load API keys from .env
10
- load_dotenv()
11
  from context_pruning_env.models import ContextAction
12
 
13
- # Setup simple logging
14
- logging.basicConfig(level=logging.INFO)
15
- logger = logging.getLogger(__name__)
 
 
16
 
17
- # Configure Gemini
18
- GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
19
- if GOOGLE_API_KEY:
20
- genai.configure(api_key=GOOGLE_API_KEY)
21
 
22
- def main():
23
- if not GOOGLE_API_KEY:
24
- logger.error("GOOGLE_API_KEY not found in environment or .env file.")
25
  return
26
 
27
- # 1. Setup Gemini Model
28
- model = genai.GenerativeModel("gemini-1.5-flash")
29
-
30
- # 2. Initialize Environment
31
- env = ContextPruningEnv(squad_split="train")
32
-
33
- # Run a few episodes across different tasks
34
- tasks = ["noise_filter", "deduplication", "sentence_distillation"]
35
 
36
- for task_name in tasks:
37
- logger.info(f"--- Running Task: {task_name} ---")
 
 
 
38
 
39
- # 3. Reset (Observation)
40
- obs = env.reset(task_name=task_name)
41
- print(f"<OBSERVATION>{obs.model_dump_json()}</OBSERVATION>")
42
 
43
- # 4. Agent Logic (Gemini Call)
44
  prompt = (
 
45
  f"Question: {obs.question}\n\n"
46
- "Below are 5 context chunks. Output ONLY a JSON list of 5 integers (0 or 1) "
47
- "where 1 means 'keep' and 0 means 'prune'. "
48
- "Prioritize keeping the answer while removing noise and duplicates.\n"
49
- f"Chunks: {json.dumps(obs.chunks, indent=2)}\n\n"
50
- "Action format: [1, 0, 1, 1, 0]"
51
  )
 
 
 
 
52
 
53
  try:
54
- response = model.generate_content(prompt)
55
- completion = response.text
 
 
 
 
56
 
57
- # Simple extraction of the mask [x,x,x,x,x]
58
- match = re.search(r"\[\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*\]", completion)
59
  if match:
60
- mask = [int(m) for m in match.groups()]
61
  else:
62
- logger.warning(f"Failed to parse mask from Gemini output: {completion}. Falling back to [1,1,1,1,1]")
63
- mask = [1, 1, 1, 1, 1]
64
  except Exception as e:
65
- logger.error(f"Gemini Inference failed: {e}")
66
- mask = [1, 1, 1, 1, 1]
67
 
68
- # 5. Take Action
69
  action = ContextAction(mask=mask)
70
- print(f"<ACTION>{action.model_dump_json()}</ACTION>")
71
-
72
- # 6. Step (Reward)
73
  final_obs = env.step(action)
74
- print(f"<REWARD>{json.dumps({'score': final_obs.reward, 'message': final_obs.message})}</REWARD>")
75
 
76
- logger.info(f"Task {task_name} Result: {final_obs.message} (Score: {final_obs.reward})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  if __name__ == "__main__":
79
- main()
 
2
  import json
3
  import logging
4
  import re
5
+ from typing import List, Optional
6
+ from openai import OpenAI
7
  from dotenv import load_dotenv
8
  from context_pruning_env.env import ContextPruningEnv
 
 
 
9
  from context_pruning_env.models import ContextAction
10
 
11
+ # Setup mandatory log format for Meta x Scaler Evaluator
12
+ logging.basicConfig(level=logging.INFO, format='%(message)s')
13
+ logger = logging.getLogger("evaluator")
14
+
15
+ load_dotenv()
16
 
17
+ # Mandatory Environment Variables
18
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
19
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gemini-1.5-flash")
20
+ HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("GOOGLE_API_KEY", ""))
21
 
22
+ def run_inference():
23
+ if not HF_TOKEN:
24
+ print("ERROR: HF_TOKEN (or GOOGLE_API_KEY) not found.")
25
  return
26
 
27
+ client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
28
+ env = ContextPruningEnv()
 
 
 
 
 
 
29
 
30
+ tasks = ["noise_purge", "dedupe_arena", "signal_extract"]
31
+
32
+ for task in tasks:
33
+ # [START] tag for automated evaluation
34
+ print(f"[START] task={task} env=contextprune model={MODEL_NAME}")
35
 
36
+ obs = env.reset(task_name=task)
 
 
37
 
38
+ step_n = 1
39
  prompt = (
40
+ f"Task: {task}\n"
41
  f"Question: {obs.question}\n\n"
42
+ "Chunks:\n"
 
 
 
 
43
  )
44
+ for i, c in enumerate(obs.chunks):
45
+ prompt += f"[{i}]: {c}\n"
46
+
47
+ prompt += "\nOutput ONLY a JSON list of indices (0 or 1) for each chunk. Example: [1, 0, 1]"
48
 
49
  try:
50
+ response = client.chat.completions.create(
51
+ model=MODEL_NAME,
52
+ messages=[{"role": "user", "content": prompt}],
53
+ temperature=0.0
54
+ )
55
+ content = response.choices[0].message.content
56
 
57
+ match = re.search(r"\[([\d\s,]+)\]", content)
 
58
  if match:
59
+ mask = json.loads(match.group(0))
60
  else:
61
+ mask = [1] * len(obs.chunks)
 
62
  except Exception as e:
63
+ logger.error(f"Inference Error: {e}")
64
+ mask = [1] * len(obs.chunks)
65
 
66
+ # Execute Action
67
  action = ContextAction(mask=mask)
 
 
 
68
  final_obs = env.step(action)
 
69
 
70
+ # [STEP] tag for each action in the trajectory
71
+ step_log = (
72
+ f"[STEP] task={task} "
73
+ f"step={step_n} "
74
+ f"action={json.dumps(mask)} "
75
+ f"reward={final_obs.reward:.2f} "
76
+ f"done={str(final_obs.done).lower()}"
77
+ )
78
+ print(step_log)
79
+
80
+ # [END] tag for episode completion
81
+ score = final_obs.metadata.get('eval_score', 0)
82
+ success = score > 0.5
83
+ end_log = (
84
+ f"[END] task={task} "
85
+ f"score={score:.2f} "
86
+ f"success={str(success).lower()} "
87
+ f"rewards={final_obs.reward:.2f}"
88
+ )
89
+ print(end_log)
90
 
91
  if __name__ == "__main__":
92
+ run_inference()
openenv.yaml CHANGED
@@ -1,12 +1,13 @@
1
  spec_version: 1
2
- name: ContextPrune
 
3
  type: space
4
- runtime: fastapi
5
  app: context_pruning_env.server.app:app
6
- port: 7860
7
  resources:
8
  cpu: 2
9
  memory: 8Gi
10
  storage: 10Gi
11
  timeout: 300
12
- description: "Adaptive Context Optimization Agent (ContextPrune): Reduces noise and tokens in RAG pipelines while preserving answer quality."
 
1
  spec_version: 1
2
+ name: contextprune
3
+ version: 0.1.0
4
  type: space
5
+ runtime: python
6
  app: context_pruning_env.server.app:app
7
+ port: 8000
8
  resources:
9
  cpu: 2
10
  memory: 8Gi
11
  storage: 10Gi
12
  timeout: 300
13
+ description: "ContextPrune: Adaptive Context Optimization Environment (Meta x Scaler Round 1 Compliance)."
requirements.txt CHANGED
@@ -11,4 +11,4 @@ python-dotenv>=1.0.0
11
  pytest>=7.4.0
12
  gradio>=4.0.0
13
  google-generativeai>=0.3.0
14
- python-dotenv>=1.0.0
 
11
  pytest>=7.4.0
12
  gradio>=4.0.0
13
  google-generativeai>=0.3.0
14
+ openai>=1.0.0
stderr.log ADDED
Binary file (3.9 kB). View file
 
stderr_utf8.log ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python : D:\Projects\RAG\context_pr
2
+ uning_env\models.py:10:
3
+ PydanticDeprecatedSince20:
4
+ `min_items` is deprecated and will
5
+ be removed, use `min_length`
6
+ instead. Deprecated in Pydantic
7
+ V2.0 to be removed in V3.0. See
8
+ Pydantic V2 Migration Guide at http
9
+ s://errors.pydantic.dev/2.12/migrat
10
+ ion/
11
+ At line:1 char:1
12
+ + python inference.py 2> stderr.log
13
+ + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14
+ + CategoryInfo : NotS
15
+ pecified: (D:\Projects\RAG...2
16
+ .12/migration/:String) [], Rem
17
+ oteException
18
+ + FullyQualifiedErrorId : Nati
19
+ veCommandError
20
+
21
+ mask: List[int] = Field(
22
+ HTTP Request: POST https://generati
23
+ velanguage.googleapis.com/v1beta/op
24
+ enai/chat/completions "HTTP/1.1
25
+ 404 Not Found"
26
+ Inference Error: Error code: 404 -
27
+ [{'error': {'code': 404,
28
+ 'message':
29
+ 'models/gemini-1.5-flash is not
30
+ found for API version v1main, or
31
+ is not supported for
32
+ generateContent. Call ListModels
33
+ to see the list of available
34
+ models and their supported
35
+ methods.', 'status': 'NOT_FOUND'}}]
36
+ HTTP Request: POST https://generati
37
+ velanguage.googleapis.com/v1beta/op
38
+ enai/chat/completions "HTTP/1.1
39
+ 404 Not Found"
40
+ Inference Error: Error code: 404 -
41
+ [{'error': {'code': 404,
42
+ 'message':
43
+ 'models/gemini-1.5-flash is not
44
+ found for API version v1main, or
45
+ is not supported for
46
+ generateContent. Call ListModels
47
+ to see the list of available
48
+ models and their supported
49
+ methods.', 'status': 'NOT_FOUND'}}]
50
+ HTTP Request: POST https://generati
51
+ velanguage.googleapis.com/v1beta/op
52
+ enai/chat/completions "HTTP/1.1
53
+ 404 Not Found"
54
+ Inference Error: Error code: 404 -
55
+ [{'error': {'code': 404,
56
+ 'message':
57
+ 'models/gemini-1.5-flash is not
58
+ found for API version v1main, or
59
+ is not supported for
60
+ generateContent. Call ListModels
61
+ to see the list of available
62
+ models and their supported
63
+ methods.', 'status': 'NOT_FOUND'}}]
stdout.log ADDED
Binary file (1.05 kB). View file
 
stdout_utf8.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ task=noise_purge env=contextprune model=gemini-1.5-flash step=1 action=[1, 1] reward=0.70 done=true error=null success=false steps=1 score=0.50 rewards=0.70
2
+ task=dedupe_arena env=contextprune model=gemini-1.5-flash step=1 action=[1, 1, 1] reward=0.70 done=true error=null success=false steps=1 score=0.50 rewards=0.70
3
+ task=signal_extract env=contextprune model=gemini-1.5-flash step=1 action=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] reward=0.70 done=true error=null success=false steps=1 score=0.10 rewards=0.70