Spaces:
Sleeping
Sleeping
Meta x Scaler Compliance: Strict logging, Port 8000 sync, and mandatory env vars.
Browse files- Dockerfile +2 -17
- README.md +57 -39
- app_ui.py +47 -20
- context_pruning_env/env.py +56 -21
- context_pruning_env/graders.py +31 -31
- context_pruning_env/models.py +12 -12
- context_pruning_env/server/app.py +6 -4
- context_pruning_env/utils.py +30 -43
- inference.py +62 -49
- openenv.yaml +5 -4
- requirements.txt +1 -1
- stderr.log +0 -0
- stderr_utf8.log +63 -0
- stdout.log +0 -0
- stdout_utf8.log +3 -0
Dockerfile
CHANGED
|
@@ -7,23 +7,8 @@ ENV PYTHONPATH=/app
|
|
| 7 |
WORKDIR /app
|
| 8 |
|
| 9 |
# Install system dependencies
|
| 10 |
-
|
| 11 |
-
build-essential \
|
| 12 |
-
curl \
|
| 13 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
-
|
| 15 |
-
# Optimize for 2 vCPU and 8GB RAM
|
| 16 |
-
# Copy and install Python dependencies separately for layer caching
|
| 17 |
COPY requirements.txt .
|
| 18 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 19 |
|
| 20 |
-
|
| 21 |
-
COPY context_pruning_env ./context_pruning_env
|
| 22 |
-
COPY inference.py .
|
| 23 |
-
COPY openenv.yaml .
|
| 24 |
-
|
| 25 |
-
# Expose the standard OpenEnv port
|
| 26 |
-
EXPOSE 7860
|
| 27 |
-
|
| 28 |
-
# FastAPI app entrypoint
|
| 29 |
-
CMD ["uvicorn", "context_pruning_env.server.app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "2"]
|
|
|
|
| 7 |
WORKDIR /app
|
| 8 |
|
| 9 |
# Install system dependencies
|
| 10 |
+
# Install dependencies
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
COPY requirements.txt .
|
| 12 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
|
| 14 |
+
CMD ["uvicorn", "context_pruning_env.server.app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,62 +1,80 @@
|
|
| 1 |
-
# Adaptive Context Optimization
|
| 2 |
|
| 3 |
-
**ContextPrune** is a
|
| 4 |
-
|
| 5 |
-
> “ContextPrune reduces noise and tokens in RAG pipelines while preserving answer quality.”
|
| 6 |
|
| 7 |
---
|
| 8 |
|
| 9 |
-
##
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
# Install dependencies
|
| 21 |
-
pip install -r requirements.txt
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
#
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
```bash
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
```
|
| 50 |
|
| 51 |
-
#
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
<ACTION>{...}</ACTION>
|
| 58 |
-
<REWARD>{...}</REWARD>
|
| 59 |
```
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
---
|
| 62 |
-
*
|
|
|
|
| 1 |
+
# ContextPrune: Adaptive Context Optimization Environment
|
| 2 |
|
| 3 |
+
**ContextPrune** is a Meta x Scaler Hackathon compliant reinforcement learning environment designed for Phase 1: Automated Validation. It focuses on the critical task of context pruning for RAG pipelines, reducing noise and token counts while strictly preserving answer faithfulness.
|
|
|
|
|
|
|
| 4 |
|
| 5 |
---
|
| 6 |
|
| 7 |
+
## 🌍 Environment Description
|
| 8 |
+
ContextPrune implements the **OpenEnv Spec**, providing a standardized interface for RL agents to optimize retrieved contexts. The environment presents a query and multiple context chunks (from SQuAD or synthetic noise) where the agent must decide which chunks to keep and which to prune using a binary mask.
|
| 9 |
|
| 10 |
+
### Resource Constraints
|
| 11 |
+
- **vCPU**: 2
|
| 12 |
+
- **RAM**: 8GB
|
| 13 |
+
- **Runtime**: Python 3.10+
|
| 14 |
+
- **Port**: 8000 (OpenEnv Server)
|
| 15 |
|
| 16 |
+
---
|
| 17 |
|
| 18 |
+
## 🎮 Action & Observation Spaces
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
### Action Space (ContextAction)
|
| 21 |
+
- **Type**: Binary Mask (`List[int]`)
|
| 22 |
+
- **Values**: `1` (Keep), `0` (Prune)
|
| 23 |
+
- **Constraint**: Must match the number of chunks in the current observation.
|
| 24 |
|
| 25 |
+
### Observation Space (ContextObservation)
|
| 26 |
+
- **question**: The user query to be answered.
|
| 27 |
+
- **chunks**: A list of text strings representing the retrieved context.
|
| 28 |
+
- **initial_token_count**: The total token count before optimization.
|
| 29 |
+
- **current_token_count**: Cumulative tokens of the currently selected chunks.
|
| 30 |
+
- **task_name**: The identifier for the current pruning task.
|
| 31 |
|
| 32 |
+
---
|
| 33 |
|
| 34 |
+
## 🏆 Task Descriptions
|
| 35 |
|
| 36 |
+
| Task ID | Name | Difficulty | Scoring Logic |
|
| 37 |
+
| :--- | :--- | :--- | :--- |
|
| 38 |
+
| **01** | `noise_purge` | **Easy** | 0.0 or 1.0. Perfect score if all noise is deleted and the answer is kept. |
|
| 39 |
+
| **02** | `dedupe_arena` | **Medium** | 1.0 if word count is reduced by >50% while preserving the answer. |
|
| 40 |
+
| **03** | `signal_extract` | **Hard** | $1 - (FinalTokens/InitialTokens)$. Score scales with compression ratio. |
|
| 41 |
|
| 42 |
+
---
|
| 43 |
|
| 44 |
+
## 📈 Reward Function (Trajectory Signals)
|
| 45 |
+
The environment emits rewards based on the agent's efficiency and accuracy:
|
| 46 |
+
- **Efficiency**: `+0.1` for every irrelevant chunk or duplicate correctly pruned.
|
| 47 |
+
- **Accuracy**: `+0.7` bonus at the end of the trajectory if the "Gold Chunk" is preserved.
|
| 48 |
+
- **Death Penalty**: `-1.0` and immediate `done=True` if the agent prunes the Gold Chunk (Information Loss).
|
| 49 |
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## 🛠️ Setup Instructions
|
| 53 |
+
|
| 54 |
+
### 1. Local Development
|
| 55 |
```bash
|
| 56 |
+
# Install dependencies
|
| 57 |
+
pip install -r requirements.txt
|
| 58 |
|
| 59 |
+
# Configure API (Optional for testing)
|
| 60 |
+
echo "GOOGLE_API_KEY=your_key" > .env
|
|
|
|
| 61 |
|
| 62 |
+
# Run Inference Evaluation
|
| 63 |
+
python inference.py
|
| 64 |
+
```
|
| 65 |
|
| 66 |
+
### 2. Docker Deployment
|
| 67 |
+
```bash
|
| 68 |
+
# Build the standardized image
|
| 69 |
+
docker build -t contextprune .
|
| 70 |
|
| 71 |
+
# Start the environment server
|
| 72 |
+
docker run -p 8000:8000 contextprune
|
|
|
|
|
|
|
| 73 |
```
|
| 74 |
|
| 75 |
+
### 3. Inference Logging
|
| 76 |
+
Mandatory logs are emitted in the following format for the Hackathon Evaluator:
|
| 77 |
+
`task=<name> env=contextprune model=<model> step=<n> action=<str> reward=<0.00> done=<bool> score=<score> rewards=<r1,r2...>`
|
| 78 |
+
|
| 79 |
---
|
| 80 |
+
*Built for the Meta x Scaler Hackathon 2026*
|
app_ui.py
CHANGED
|
@@ -30,16 +30,20 @@ async def call_gemini(prompt: str, model_name: str = "gemini-1.5-flash") -> str:
|
|
| 30 |
except Exception as e:
|
| 31 |
return f"ERROR: {str(e)}"
|
| 32 |
|
| 33 |
-
def chunk_text(text: str, max_chunks: int =
|
| 34 |
"""Split text into manageable chunks (paragraphs or sentences)."""
|
| 35 |
-
#
|
| 36 |
-
|
| 37 |
-
if len(chunks) < 2:
|
| 38 |
-
# Split by sentence if only one paragraph
|
| 39 |
-
chunks = [c.strip() for c in re.split(r'(?<=[.!?])\s+', text) if c.strip()]
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
|
| 45 |
"""
|
|
@@ -53,26 +57,49 @@ async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
|
|
| 53 |
# Prompt for selection
|
| 54 |
selection_prompt = (
|
| 55 |
f"Query: {query}\n\n"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
"Chunks:\n"
|
| 60 |
)
|
| 61 |
for i, c in enumerate(chunks):
|
| 62 |
selection_prompt += f"Chunk {i}: {c}\n\n"
|
| 63 |
|
| 64 |
raw_response = await call_gemini(selection_prompt)
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
else:
|
| 75 |
-
|
|
|
|
|
|
|
| 76 |
|
| 77 |
optimized_text = " ".join(kept_chunks)
|
| 78 |
|
|
|
|
| 30 |
except Exception as e:
|
| 31 |
return f"ERROR: {str(e)}"
|
| 32 |
|
| 33 |
+
def chunk_text(text: str, max_chunks: int = 10) -> List[str]:
|
| 34 |
"""Split text into manageable chunks (paragraphs or sentences)."""
|
| 35 |
+
# 1. First split by double newlines (paragraphs)
|
| 36 |
+
initial_chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
final_chunks = []
|
| 39 |
+
# 2. If paragraphs are too few or long, split them into sentences
|
| 40 |
+
for chunk in initial_chunks:
|
| 41 |
+
# Split by sentence markers [.!?] followed by space or newline
|
| 42 |
+
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+|\n', chunk) if s.strip()]
|
| 43 |
+
final_chunks.extend(sentences)
|
| 44 |
+
|
| 45 |
+
# Simple limit to 10 chunks to avoid overwhelming the prompt
|
| 46 |
+
return final_chunks[:max_chunks]
|
| 47 |
|
| 48 |
async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
|
| 49 |
"""
|
|
|
|
| 57 |
# Prompt for selection
|
| 58 |
selection_prompt = (
|
| 59 |
f"Query: {query}\n\n"
|
| 60 |
+
"TASK: Select indices of context chunks that are directly relevant to the query. "
|
| 61 |
+
"Remove noise, random facts, and duplicates. "
|
| 62 |
+
"OUTPUT: Output ONLY the list of indices as a JSON array like [0, 2, 4]. No explanations.\n\n"
|
| 63 |
"Chunks:\n"
|
| 64 |
)
|
| 65 |
for i, c in enumerate(chunks):
|
| 66 |
selection_prompt += f"Chunk {i}: {c}\n\n"
|
| 67 |
|
| 68 |
raw_response = await call_gemini(selection_prompt)
|
| 69 |
+
print(f"DEBUG: Gemini Response: {raw_response}")
|
| 70 |
|
| 71 |
+
from context_pruning_env.graders import (
|
| 72 |
+
grade_noise_purge,
|
| 73 |
+
grade_dedupe_arena,
|
| 74 |
+
grade_signal_extract
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Ultra-robust extraction
|
| 78 |
+
indices = []
|
| 79 |
+
try:
|
| 80 |
+
match = re.search(r"\[([\d\s,]+)\]", raw_response)
|
| 81 |
+
if match:
|
| 82 |
+
# Found a bracketed list of numbers
|
| 83 |
+
content = match.group(0) # e.g. "[0, 1, 2]"
|
| 84 |
+
indices = json.loads(content)
|
| 85 |
+
else:
|
| 86 |
+
# Try finding any numbers in the response if no brackets
|
| 87 |
+
nums = re.findall(r"\d+", raw_response)
|
| 88 |
+
indices = [int(n) for n in nums]
|
| 89 |
+
|
| 90 |
+
# Clean up: only valid unique indices
|
| 91 |
+
indices = list(set([int(i) for i in indices if isinstance(i, int) and 0 <= i < len(chunks)]))
|
| 92 |
+
print(f"DEBUG: Successfully extracted indices: {indices}")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
print(f"DEBUG: Extraction Error: {e}")
|
| 95 |
+
indices = []
|
| 96 |
+
|
| 97 |
+
if indices:
|
| 98 |
+
kept_chunks = [chunks[i] for i in sorted(indices)]
|
| 99 |
else:
|
| 100 |
+
# Fallback to keep everything if AI fails, but message it
|
| 101 |
+
print("DEBUG: Pruning failed, keeping original context.")
|
| 102 |
+
kept_chunks = chunks
|
| 103 |
|
| 104 |
optimized_text = " ".join(kept_chunks)
|
| 105 |
|
context_pruning_env/env.py
CHANGED
|
@@ -12,9 +12,9 @@ from context_pruning_env.models import (
|
|
| 12 |
)
|
| 13 |
from context_pruning_env.utils import SQuADLoader, count_tokens
|
| 14 |
from context_pruning_env.graders import (
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
)
|
| 19 |
|
| 20 |
class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningState]):
|
|
@@ -31,13 +31,14 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
|
|
| 31 |
self,
|
| 32 |
seed: Optional[int] = None,
|
| 33 |
episode_id: Optional[str] = None,
|
| 34 |
-
task_name: Optional[str] = "
|
| 35 |
**kwargs: Any,
|
| 36 |
) -> ContextObservation:
|
| 37 |
"""
|
| 38 |
Starts a new episode with the specified task.
|
| 39 |
"""
|
| 40 |
-
|
|
|
|
| 41 |
|
| 42 |
chunks = []
|
| 43 |
total_tokens = 0
|
|
@@ -53,7 +54,7 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
|
|
| 53 |
|
| 54 |
self._state = PruningState(
|
| 55 |
episode_id=episode_id or str(uuid4()),
|
| 56 |
-
task_name=task_name
|
| 57 |
question=question,
|
| 58 |
chunks=chunks,
|
| 59 |
initial_tokens=total_tokens,
|
|
@@ -69,7 +70,8 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
|
|
| 69 |
done=self._state.done,
|
| 70 |
question=self._state.question,
|
| 71 |
chunks=[c.content for c in self._state.chunks],
|
| 72 |
-
|
|
|
|
| 73 |
task_name=self._state.task_name,
|
| 74 |
message=message
|
| 75 |
)
|
|
@@ -80,34 +82,67 @@ class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningSt
|
|
| 80 |
**kwargs: Any,
|
| 81 |
) -> ContextObservation:
|
| 82 |
"""
|
| 83 |
-
Takes a binary mask and
|
| 84 |
"""
|
| 85 |
if self._state.done:
|
| 86 |
return self._observe(message="Episode is already done.")
|
| 87 |
|
| 88 |
mask = action.mask
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
else:
|
| 98 |
-
|
| 99 |
|
| 100 |
self._state.done = True
|
| 101 |
self._state.step_count += 1
|
| 102 |
|
| 103 |
-
|
| 104 |
-
obs =
|
| 105 |
-
obs.reward = reward_obj.score
|
| 106 |
|
| 107 |
-
# Add detailed reward to metadata for hackathon transparency
|
| 108 |
if not obs.metadata:
|
| 109 |
obs.metadata = {}
|
| 110 |
-
obs.metadata["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
return obs
|
| 113 |
|
|
|
|
| 12 |
)
|
| 13 |
from context_pruning_env.utils import SQuADLoader, count_tokens
|
| 14 |
from context_pruning_env.graders import (
|
| 15 |
+
grade_noise_purge,
|
| 16 |
+
grade_dedupe_arena,
|
| 17 |
+
grade_signal_extract
|
| 18 |
)
|
| 19 |
|
| 20 |
class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningState]):
|
|
|
|
| 31 |
self,
|
| 32 |
seed: Optional[int] = None,
|
| 33 |
episode_id: Optional[str] = None,
|
| 34 |
+
task_name: Optional[str] = "noise_purge",
|
| 35 |
**kwargs: Any,
|
| 36 |
) -> ContextObservation:
|
| 37 |
"""
|
| 38 |
Starts a new episode with the specified task.
|
| 39 |
"""
|
| 40 |
+
task_name = task_name or "noise_purge"
|
| 41 |
+
question, chunks_data = self.loader.get_episode(task_name)
|
| 42 |
|
| 43 |
chunks = []
|
| 44 |
total_tokens = 0
|
|
|
|
| 54 |
|
| 55 |
self._state = PruningState(
|
| 56 |
episode_id=episode_id or str(uuid4()),
|
| 57 |
+
task_name=task_name,
|
| 58 |
question=question,
|
| 59 |
chunks=chunks,
|
| 60 |
initial_tokens=total_tokens,
|
|
|
|
| 70 |
done=self._state.done,
|
| 71 |
question=self._state.question,
|
| 72 |
chunks=[c.content for c in self._state.chunks],
|
| 73 |
+
initial_token_count=self._state.initial_tokens,
|
| 74 |
+
current_token_count=sum(c.tokens for c in self._state.chunks),
|
| 75 |
task_name=self._state.task_name,
|
| 76 |
message=message
|
| 77 |
)
|
|
|
|
| 82 |
**kwargs: Any,
|
| 83 |
) -> ContextObservation:
|
| 84 |
"""
|
| 85 |
+
Takes a binary mask and calculates rewards based on trajectory signals.
|
| 86 |
"""
|
| 87 |
if self._state.done:
|
| 88 |
return self._observe(message="Episode is already done.")
|
| 89 |
|
| 90 |
mask = action.mask
|
| 91 |
+
if len(mask) != len(self._state.chunks):
|
| 92 |
+
# Pad or truncate mask to match chunk count if agent is misaligned
|
| 93 |
+
mask = (mask + [1] * len(self._state.chunks))[:len(self._state.chunks)]
|
| 94 |
|
| 95 |
+
# Trajectory Simulation Logic
|
| 96 |
+
total_reward = 0.0
|
| 97 |
+
efficiency_reward = 0.0
|
| 98 |
+
accuracy_reward = 0.0
|
| 99 |
+
gold_penalty = 0.0
|
| 100 |
+
success = True
|
| 101 |
+
|
| 102 |
+
for i, kept in enumerate(mask):
|
| 103 |
+
chunk = self._state.chunks[i]
|
| 104 |
+
if not kept: # Pruned
|
| 105 |
+
if chunk.is_gold:
|
| 106 |
+
# Critical Failure
|
| 107 |
+
gold_penalty = -1.0
|
| 108 |
+
success = False
|
| 109 |
+
break # Immediate stop
|
| 110 |
+
else:
|
| 111 |
+
# Correctly pruned noise/duplicate
|
| 112 |
+
efficiency_reward += 0.1
|
| 113 |
+
else: # Kept
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
# Final Accuracy Bonus
|
| 117 |
+
if success:
|
| 118 |
+
accuracy_reward = 0.7
|
| 119 |
+
|
| 120 |
+
total_reward = efficiency_reward + accuracy_reward + gold_penalty
|
| 121 |
+
|
| 122 |
+
# Task Score (Normalized 0.0 to 1.0 for the evaluator)
|
| 123 |
+
if self._state.task_name == "noise_purge":
|
| 124 |
+
score_obj = grade_noise_purge(mask, self._state.chunks)
|
| 125 |
+
elif self._state.task_name == "dedupe_arena":
|
| 126 |
+
score_obj = grade_dedupe_arena(mask, self._state.chunks)
|
| 127 |
+
elif self._state.task_name == "signal_extract":
|
| 128 |
+
score_obj = grade_signal_extract(mask, self._state.chunks)
|
| 129 |
else:
|
| 130 |
+
score_obj = grade_noise_purge(mask, self._state.chunks)
|
| 131 |
|
| 132 |
self._state.done = True
|
| 133 |
self._state.step_count += 1
|
| 134 |
|
| 135 |
+
obs = self._observe(message=score_obj.message)
|
| 136 |
+
obs.reward = total_reward # Trajectory reward
|
|
|
|
| 137 |
|
|
|
|
| 138 |
if not obs.metadata:
|
| 139 |
obs.metadata = {}
|
| 140 |
+
obs.metadata["eval_score"] = score_obj.score # Grader score
|
| 141 |
+
obs.metadata["reward_detail"] = {
|
| 142 |
+
"efficiency": efficiency_reward,
|
| 143 |
+
"accuracy": accuracy_reward,
|
| 144 |
+
"penalty": gold_penalty
|
| 145 |
+
}
|
| 146 |
|
| 147 |
return obs
|
| 148 |
|
context_pruning_env/graders.py
CHANGED
|
@@ -1,56 +1,56 @@
|
|
| 1 |
from typing import List
|
| 2 |
from context_pruning_env.models import ChunkItem, ContextReward
|
| 3 |
|
| 4 |
-
def
|
| 5 |
"""
|
| 6 |
-
|
| 7 |
"""
|
| 8 |
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
| 9 |
noise_pruned = all(mask[i] == 0 for i in range(len(mask)) if not chunks[i].is_gold)
|
| 10 |
|
| 11 |
if not gold_kept:
|
| 12 |
-
return ContextReward(score=0.0,
|
| 13 |
|
| 14 |
if noise_pruned:
|
| 15 |
-
return ContextReward(score=1.0,
|
| 16 |
else:
|
| 17 |
-
return ContextReward(score=0.5,
|
| 18 |
|
| 19 |
-
def
|
| 20 |
"""
|
| 21 |
-
|
| 22 |
"""
|
| 23 |
-
|
| 24 |
-
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
-
if
|
| 30 |
-
return ContextReward(score=0.0, message="Critical
|
| 31 |
|
| 32 |
-
if
|
| 33 |
-
return ContextReward(score=1.0, message="
|
| 34 |
-
elif kept_gold_count > 1:
|
| 35 |
-
return ContextReward(score=0.4, message="Partial: Duplicates detected.")
|
| 36 |
else:
|
| 37 |
-
return ContextReward(score=0.5, message="Partial:
|
| 38 |
|
| 39 |
-
def
|
| 40 |
"""
|
| 41 |
-
|
| 42 |
"""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
pruned_noise_count = sum(1 for i in noise_indices if mask[i] == 0)
|
| 48 |
|
| 49 |
if not gold_kept:
|
| 50 |
-
return ContextReward(score=0.0, message="Critical
|
| 51 |
-
|
| 52 |
-
if
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
from context_pruning_env.models import ChunkItem, ContextReward
|
| 3 |
|
| 4 |
+
def grade_noise_purge(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
|
| 5 |
"""
|
| 6 |
+
Easy Task: Score 1.0 if gold kept AND noise pruned.
|
| 7 |
"""
|
| 8 |
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
| 9 |
noise_pruned = all(mask[i] == 0 for i in range(len(mask)) if not chunks[i].is_gold)
|
| 10 |
|
| 11 |
if not gold_kept:
|
| 12 |
+
return ContextReward(score=0.0, gold_penalty=-1.0, message="Critical: Gold chunk lost.")
|
| 13 |
|
| 14 |
if noise_pruned:
|
| 15 |
+
return ContextReward(score=1.0, message="Perfect: All noise purged.")
|
| 16 |
else:
|
| 17 |
+
return ContextReward(score=0.5, message="Partial: Gold kept but noise remains.")
|
| 18 |
|
| 19 |
+
def grade_dedupe_arena(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
|
| 20 |
"""
|
| 21 |
+
Medium Task: 1.0 if word count reduced > 50% AND gold kept.
|
| 22 |
"""
|
| 23 |
+
initial_words = sum(len(c.content.split()) for c in chunks)
|
| 24 |
+
final_words = sum(len(chunks[i].content.split()) for i, kept in enumerate(mask) if kept)
|
| 25 |
|
| 26 |
+
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
| 27 |
+
reduction = 1.0 - (final_words / initial_words) if initial_words > 0 else 1.0
|
| 28 |
|
| 29 |
+
if not gold_kept:
|
| 30 |
+
return ContextReward(score=0.0, message="Critical: Answer lost during deduplication.")
|
| 31 |
|
| 32 |
+
if reduction >= 0.5:
|
| 33 |
+
return ContextReward(score=1.0, message=f"Great: {reduction:.1%} word reduction achieved.")
|
|
|
|
|
|
|
| 34 |
else:
|
| 35 |
+
return ContextReward(score=0.5, message=f"Partial: Only {reduction:.1%} reduction.")
|
| 36 |
|
| 37 |
+
def grade_signal_extract(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
|
| 38 |
"""
|
| 39 |
+
Hard Task: 1 - (FinalTokens/InitialTokens) if gold kept.
|
| 40 |
"""
|
| 41 |
+
initial_tokens = sum(c.tokens for c in chunks)
|
| 42 |
+
final_tokens = sum(chunks[i].tokens for i, kept in enumerate(mask) if kept)
|
| 43 |
|
| 44 |
+
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
|
|
|
| 45 |
|
| 46 |
if not gold_kept:
|
| 47 |
+
return ContextReward(score=0.0, message="Critical: Signal lost in noise.")
|
| 48 |
+
|
| 49 |
+
reduction_score = 1.0 - (final_tokens / initial_tokens) if initial_tokens > 0 else 0.0
|
| 50 |
+
# Ensure score is at least positive if gold is kept
|
| 51 |
+
final_score = max(0.1, reduction_score)
|
| 52 |
+
|
| 53 |
+
return ContextReward(
|
| 54 |
+
score=final_score,
|
| 55 |
+
message=f"Signal Extracted: {reduction_score:.1%} compression."
|
| 56 |
+
)
|
context_pruning_env/models.py
CHANGED
|
@@ -5,13 +5,12 @@ from openenv.core.env_server.types import Action, Observation, State
|
|
| 5 |
|
| 6 |
class ContextAction(Action):
|
| 7 |
"""
|
| 8 |
-
Action space: A binary mask of
|
| 9 |
"""
|
| 10 |
mask: List[int] = Field(
|
| 11 |
...,
|
| 12 |
-
min_items=
|
| 13 |
-
|
| 14 |
-
description="Binary mask of 5 integers (0 or 1) indicating which chunks to keep."
|
| 15 |
)
|
| 16 |
|
| 17 |
class ContextObservation(Observation):
|
|
@@ -19,31 +18,32 @@ class ContextObservation(Observation):
|
|
| 19 |
Observation provided to the agent.
|
| 20 |
"""
|
| 21 |
question: str
|
| 22 |
-
chunks: List[str] = Field(default_factory=list, description="
|
| 23 |
-
|
|
|
|
| 24 |
task_name: str = ""
|
| 25 |
message: str = ""
|
| 26 |
|
| 27 |
class ContextReward(BaseModel):
|
| 28 |
"""
|
| 29 |
-
Detailed reward breakdown.
|
| 30 |
"""
|
| 31 |
score: float = Field(0.0, ge=0.0, le=1.0, description="Overall task score (0 to 1).")
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
message: str = ""
|
| 36 |
|
| 37 |
class ChunkItem(BaseModel):
|
| 38 |
"""Internal representation of a context chunk."""
|
| 39 |
content: str
|
| 40 |
is_gold: bool = False
|
| 41 |
-
is_duplicate: bool = False
|
| 42 |
tokens: int = 0
|
|
|
|
| 43 |
|
| 44 |
class PruningState(State):
|
| 45 |
"""
|
| 46 |
-
Internal state
|
| 47 |
"""
|
| 48 |
task_name: str
|
| 49 |
question: str
|
|
|
|
| 5 |
|
| 6 |
class ContextAction(Action):
|
| 7 |
"""
|
| 8 |
+
Action space: A binary mask of N values (1 = keep, 0 = prune).
|
| 9 |
"""
|
| 10 |
mask: List[int] = Field(
|
| 11 |
...,
|
| 12 |
+
min_items=1,
|
| 13 |
+
description="Binary mask of integers (0 or 1) indicating which chunks to keep."
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
class ContextObservation(Observation):
|
|
|
|
| 18 |
Observation provided to the agent.
|
| 19 |
"""
|
| 20 |
question: str
|
| 21 |
+
chunks: List[str] = Field(default_factory=list, description="Current context chunks.")
|
| 22 |
+
initial_token_count: int = 0
|
| 23 |
+
current_token_count: int = 0
|
| 24 |
task_name: str = ""
|
| 25 |
message: str = ""
|
| 26 |
|
| 27 |
class ContextReward(BaseModel):
|
| 28 |
"""
|
| 29 |
+
Detailed reward breakdown for Meta x Scaler audit.
|
| 30 |
"""
|
| 31 |
score: float = Field(0.0, ge=0.0, le=1.0, description="Overall task score (0 to 1).")
|
| 32 |
+
efficiency_reward: float = 0.0
|
| 33 |
+
accuracy_reward: float = 0.0
|
| 34 |
+
gold_penalty: float = 0.0
|
| 35 |
message: str = ""
|
| 36 |
|
| 37 |
class ChunkItem(BaseModel):
|
| 38 |
"""Internal representation of a context chunk."""
|
| 39 |
content: str
|
| 40 |
is_gold: bool = False
|
|
|
|
| 41 |
tokens: int = 0
|
| 42 |
+
is_duplicate: bool = False
|
| 43 |
|
| 44 |
class PruningState(State):
|
| 45 |
"""
|
| 46 |
+
Internal state for ContextPrune.
|
| 47 |
"""
|
| 48 |
task_name: str
|
| 49 |
question: str
|
context_pruning_env/server/app.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
| 1 |
import os
|
| 2 |
from openenv.core.env_server.http_server import create_fastapi_app
|
| 3 |
from context_pruning_env.env import ContextPruningEnv
|
| 4 |
-
from context_pruning_env.models import
|
| 5 |
|
| 6 |
app = create_fastapi_app(
|
| 7 |
ContextPruningEnv,
|
| 8 |
-
|
| 9 |
-
|
| 10 |
)
|
| 11 |
|
| 12 |
def main() -> None:
|
| 13 |
import uvicorn
|
| 14 |
-
port = int(os.environ.get("PORT", "
|
| 15 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 16 |
|
| 17 |
if __name__ == "__main__":
|
| 18 |
main()
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
from openenv.core.env_server.http_server import create_fastapi_app
|
| 3 |
from context_pruning_env.env import ContextPruningEnv
|
| 4 |
+
from context_pruning_env.models import ContextAction, ContextObservation
|
| 5 |
|
| 6 |
app = create_fastapi_app(
|
| 7 |
ContextPruningEnv,
|
| 8 |
+
ContextAction,
|
| 9 |
+
ContextObservation,
|
| 10 |
)
|
| 11 |
|
| 12 |
def main() -> None:
|
| 13 |
import uvicorn
|
| 14 |
+
port = int(os.environ.get("PORT", "8000"))
|
| 15 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 16 |
|
| 17 |
if __name__ == "__main__":
|
| 18 |
main()
|
| 19 |
+
|
| 20 |
+
|
context_pruning_env/utils.py
CHANGED
|
@@ -35,57 +35,44 @@ class SQuADLoader:
|
|
| 35 |
|
| 36 |
chunks = []
|
| 37 |
|
| 38 |
-
if task_name == "
|
| 39 |
-
# 1 Gold
|
| 40 |
chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
|
| 44 |
|
| 45 |
-
elif task_name == "
|
| 46 |
-
#
|
| 47 |
chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
|
| 53 |
-
elif task_name == "
|
| 54 |
-
#
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
if answer_text.lower() in s.lower() and gold_sentence is None:
|
| 63 |
-
gold_sentence = s
|
| 64 |
-
else:
|
| 65 |
-
other_sentences.append(s)
|
| 66 |
-
|
| 67 |
-
if gold_sentence is None:
|
| 68 |
-
# Fallback if answer spans multiple sentences or is not found cleanly
|
| 69 |
-
gold_sentence = sentences[0]
|
| 70 |
-
other_sentences = sentences[1:]
|
| 71 |
-
|
| 72 |
-
chunks.append({"content": gold_sentence, "is_gold": True, "is_duplicate": False})
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
random.shuffle(
|
| 76 |
-
for
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
else:
|
| 80 |
-
_, noise_entry = self._get_next_entry()
|
| 81 |
-
chunks.append({"content": noise_entry["context"][:100], "is_gold": False, "is_duplicate": False})
|
| 82 |
|
| 83 |
else:
|
| 84 |
-
# Default to
|
| 85 |
-
return self.get_episode("
|
| 86 |
|
| 87 |
-
# Shuffle
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
return question, chunks
|
| 90 |
|
| 91 |
def count_tokens(text: str) -> int:
|
|
|
|
| 35 |
|
| 36 |
chunks = []
|
| 37 |
|
| 38 |
+
if task_name == "noise_purge":
|
| 39 |
+
# Easy: 1 Gold + 1 Irrelevant
|
| 40 |
chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
|
| 41 |
+
_, noise_entry = self._get_next_entry()
|
| 42 |
+
chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
|
|
|
|
| 43 |
|
| 44 |
+
elif task_name == "dedupe_arena":
|
| 45 |
+
# Medium: 1 Gold + 2 Near-Duplicates (Simulated by repeating gold)
|
| 46 |
chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
|
| 47 |
+
# Duplicate 1: slightly modified or identical
|
| 48 |
+
chunks.append({"content": gold_context + " ", "is_gold": True, "is_duplicate": True})
|
| 49 |
+
# Duplicate 2: slightly modified
|
| 50 |
+
chunks.append({"content": "Actually, " + gold_context, "is_gold": True, "is_duplicate": True})
|
| 51 |
|
| 52 |
+
elif task_name == "signal_extract":
|
| 53 |
+
# Hard: 1 Long context (2,000+ words)
|
| 54 |
+
# We simulate this by taking 10 random SQuAD contexts and joining them.
|
| 55 |
+
# Only one contains the answer.
|
| 56 |
+
long_context_parts = []
|
| 57 |
+
long_context_parts.append(gold_context)
|
| 58 |
+
for _ in range(15): # ~15 chunks of ~150 words = ~2250 words
|
| 59 |
+
_, noise_entry = self._get_next_entry()
|
| 60 |
+
long_context_parts.append(noise_entry["context"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# Shuffling the parts so the gold one isn't first
|
| 63 |
+
random.shuffle(long_context_parts)
|
| 64 |
+
for part in long_context_parts:
|
| 65 |
+
is_gold = (part == gold_context)
|
| 66 |
+
chunks.append({"content": part, "is_gold": is_gold, "is_duplicate": False})
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
else:
|
| 69 |
+
# Default to noise_purge
|
| 70 |
+
return self.get_episode("noise_purge")
|
| 71 |
|
| 72 |
+
# Shuffle chunks for non-signal tasks
|
| 73 |
+
if task_name != "signal_extract":
|
| 74 |
+
random.shuffle(chunks)
|
| 75 |
+
|
| 76 |
return question, chunks
|
| 77 |
|
| 78 |
def count_tokens(text: str) -> int:
|
inference.py
CHANGED
|
@@ -2,78 +2,91 @@ import os
|
|
| 2 |
import json
|
| 3 |
import logging
|
| 4 |
import re
|
| 5 |
-
|
|
|
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from context_pruning_env.env import ContextPruningEnv
|
| 8 |
-
|
| 9 |
-
# Load API keys from .env
|
| 10 |
-
load_dotenv()
|
| 11 |
from context_pruning_env.models import ContextAction
|
| 12 |
|
| 13 |
-
# Setup
|
| 14 |
-
logging.basicConfig(level=logging.INFO)
|
| 15 |
-
logger = logging.getLogger(
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
|
| 22 |
-
def
|
| 23 |
-
if not
|
| 24 |
-
|
| 25 |
return
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# 2. Initialize Environment
|
| 31 |
-
env = ContextPruningEnv(squad_split="train")
|
| 32 |
-
|
| 33 |
-
# Run a few episodes across different tasks
|
| 34 |
-
tasks = ["noise_filter", "deduplication", "sentence_distillation"]
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
obs = env.reset(task_name=task_name)
|
| 41 |
-
print(f"<OBSERVATION>{obs.model_dump_json()}</OBSERVATION>")
|
| 42 |
|
| 43 |
-
|
| 44 |
prompt = (
|
|
|
|
| 45 |
f"Question: {obs.question}\n\n"
|
| 46 |
-
"
|
| 47 |
-
"where 1 means 'keep' and 0 means 'prune'. "
|
| 48 |
-
"Prioritize keeping the answer while removing noise and duplicates.\n"
|
| 49 |
-
f"Chunks: {json.dumps(obs.chunks, indent=2)}\n\n"
|
| 50 |
-
"Action format: [1, 0, 1, 1, 0]"
|
| 51 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
try:
|
| 54 |
-
response =
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
match = re.search(r"\[\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*\]", completion)
|
| 59 |
if match:
|
| 60 |
-
mask =
|
| 61 |
else:
|
| 62 |
-
|
| 63 |
-
mask = [1, 1, 1, 1, 1]
|
| 64 |
except Exception as e:
|
| 65 |
-
logger.error(f"
|
| 66 |
-
mask = [1
|
| 67 |
|
| 68 |
-
#
|
| 69 |
action = ContextAction(mask=mask)
|
| 70 |
-
print(f"<ACTION>{action.model_dump_json()}</ACTION>")
|
| 71 |
-
|
| 72 |
-
# 6. Step (Reward)
|
| 73 |
final_obs = env.step(action)
|
| 74 |
-
print(f"<REWARD>{json.dumps({'score': final_obs.reward, 'message': final_obs.message})}</REWARD>")
|
| 75 |
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
if __name__ == "__main__":
|
| 79 |
-
|
|
|
|
| 2 |
import json
|
| 3 |
import logging
|
| 4 |
import re
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
from openai import OpenAI
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from context_pruning_env.env import ContextPruningEnv
|
|
|
|
|
|
|
|
|
|
| 9 |
from context_pruning_env.models import ContextAction
|
| 10 |
|
| 11 |
+
# Setup mandatory log format for Meta x Scaler Evaluator
|
| 12 |
+
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
| 13 |
+
logger = logging.getLogger("evaluator")
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
|
| 17 |
+
# Mandatory Environment Variables
|
| 18 |
+
API_BASE_URL = os.environ.get("API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
|
| 19 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gemini-1.5-flash")
|
| 20 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("GOOGLE_API_KEY", ""))
|
| 21 |
|
| 22 |
+
def run_inference():
|
| 23 |
+
if not HF_TOKEN:
|
| 24 |
+
print("ERROR: HF_TOKEN (or GOOGLE_API_KEY) not found.")
|
| 25 |
return
|
| 26 |
|
| 27 |
+
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
|
| 28 |
+
env = ContextPruningEnv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
tasks = ["noise_purge", "dedupe_arena", "signal_extract"]
|
| 31 |
+
|
| 32 |
+
for task in tasks:
|
| 33 |
+
# [START] tag for automated evaluation
|
| 34 |
+
print(f"[START] task={task} env=contextprune model={MODEL_NAME}")
|
| 35 |
|
| 36 |
+
obs = env.reset(task_name=task)
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
step_n = 1
|
| 39 |
prompt = (
|
| 40 |
+
f"Task: {task}\n"
|
| 41 |
f"Question: {obs.question}\n\n"
|
| 42 |
+
"Chunks:\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
)
|
| 44 |
+
for i, c in enumerate(obs.chunks):
|
| 45 |
+
prompt += f"[{i}]: {c}\n"
|
| 46 |
+
|
| 47 |
+
prompt += "\nOutput ONLY a JSON list of indices (0 or 1) for each chunk. Example: [1, 0, 1]"
|
| 48 |
|
| 49 |
try:
|
| 50 |
+
response = client.chat.completions.create(
|
| 51 |
+
model=MODEL_NAME,
|
| 52 |
+
messages=[{"role": "user", "content": prompt}],
|
| 53 |
+
temperature=0.0
|
| 54 |
+
)
|
| 55 |
+
content = response.choices[0].message.content
|
| 56 |
|
| 57 |
+
match = re.search(r"\[([\d\s,]+)\]", content)
|
|
|
|
| 58 |
if match:
|
| 59 |
+
mask = json.loads(match.group(0))
|
| 60 |
else:
|
| 61 |
+
mask = [1] * len(obs.chunks)
|
|
|
|
| 62 |
except Exception as e:
|
| 63 |
+
logger.error(f"Inference Error: {e}")
|
| 64 |
+
mask = [1] * len(obs.chunks)
|
| 65 |
|
| 66 |
+
# Execute Action
|
| 67 |
action = ContextAction(mask=mask)
|
|
|
|
|
|
|
|
|
|
| 68 |
final_obs = env.step(action)
|
|
|
|
| 69 |
|
| 70 |
+
# [STEP] tag for each action in the trajectory
|
| 71 |
+
step_log = (
|
| 72 |
+
f"[STEP] task={task} "
|
| 73 |
+
f"step={step_n} "
|
| 74 |
+
f"action={json.dumps(mask)} "
|
| 75 |
+
f"reward={final_obs.reward:.2f} "
|
| 76 |
+
f"done={str(final_obs.done).lower()}"
|
| 77 |
+
)
|
| 78 |
+
print(step_log)
|
| 79 |
+
|
| 80 |
+
# [END] tag for episode completion
|
| 81 |
+
score = final_obs.metadata.get('eval_score', 0)
|
| 82 |
+
success = score > 0.5
|
| 83 |
+
end_log = (
|
| 84 |
+
f"[END] task={task} "
|
| 85 |
+
f"score={score:.2f} "
|
| 86 |
+
f"success={str(success).lower()} "
|
| 87 |
+
f"rewards={final_obs.reward:.2f}"
|
| 88 |
+
)
|
| 89 |
+
print(end_log)
|
| 90 |
|
| 91 |
if __name__ == "__main__":
|
| 92 |
+
run_inference()
|
openenv.yaml
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
spec_version: 1
|
| 2 |
-
name:
|
|
|
|
| 3 |
type: space
|
| 4 |
-
runtime:
|
| 5 |
app: context_pruning_env.server.app:app
|
| 6 |
-
port:
|
| 7 |
resources:
|
| 8 |
cpu: 2
|
| 9 |
memory: 8Gi
|
| 10 |
storage: 10Gi
|
| 11 |
timeout: 300
|
| 12 |
-
description: "Adaptive Context Optimization
|
|
|
|
| 1 |
spec_version: 1
|
| 2 |
+
name: contextprune
|
| 3 |
+
version: 0.1.0
|
| 4 |
type: space
|
| 5 |
+
runtime: python
|
| 6 |
app: context_pruning_env.server.app:app
|
| 7 |
+
port: 8000
|
| 8 |
resources:
|
| 9 |
cpu: 2
|
| 10 |
memory: 8Gi
|
| 11 |
storage: 10Gi
|
| 12 |
timeout: 300
|
| 13 |
+
description: "ContextPrune: Adaptive Context Optimization Environment (Meta x Scaler Round 1 Compliance)."
|
requirements.txt
CHANGED
|
@@ -11,4 +11,4 @@ python-dotenv>=1.0.0
|
|
| 11 |
pytest>=7.4.0
|
| 12 |
gradio>=4.0.0
|
| 13 |
google-generativeai>=0.3.0
|
| 14 |
-
|
|
|
|
| 11 |
pytest>=7.4.0
|
| 12 |
gradio>=4.0.0
|
| 13 |
google-generativeai>=0.3.0
|
| 14 |
+
openai>=1.0.0
|
stderr.log
ADDED
|
Binary file (3.9 kB). View file
|
|
|
stderr_utf8.log
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python : D:\Projects\RAG\context_pr
|
| 2 |
+
uning_env\models.py:10:
|
| 3 |
+
PydanticDeprecatedSince20:
|
| 4 |
+
`min_items` is deprecated and will
|
| 5 |
+
be removed, use `min_length`
|
| 6 |
+
instead. Deprecated in Pydantic
|
| 7 |
+
V2.0 to be removed in V3.0. See
|
| 8 |
+
Pydantic V2 Migration Guide at http
|
| 9 |
+
s://errors.pydantic.dev/2.12/migrat
|
| 10 |
+
ion/
|
| 11 |
+
At line:1 char:1
|
| 12 |
+
+ python inference.py 2> stderr.log
|
| 13 |
+
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 14 |
+
+ CategoryInfo : NotS
|
| 15 |
+
pecified: (D:\Projects\RAG...2
|
| 16 |
+
.12/migration/:String) [], Rem
|
| 17 |
+
oteException
|
| 18 |
+
+ FullyQualifiedErrorId : Nati
|
| 19 |
+
veCommandError
|
| 20 |
+
|
| 21 |
+
mask: List[int] = Field(
|
| 22 |
+
HTTP Request: POST https://generati
|
| 23 |
+
velanguage.googleapis.com/v1beta/op
|
| 24 |
+
enai/chat/completions "HTTP/1.1
|
| 25 |
+
404 Not Found"
|
| 26 |
+
Inference Error: Error code: 404 -
|
| 27 |
+
[{'error': {'code': 404,
|
| 28 |
+
'message':
|
| 29 |
+
'models/gemini-1.5-flash is not
|
| 30 |
+
found for API version v1main, or
|
| 31 |
+
is not supported for
|
| 32 |
+
generateContent. Call ListModels
|
| 33 |
+
to see the list of available
|
| 34 |
+
models and their supported
|
| 35 |
+
methods.', 'status': 'NOT_FOUND'}}]
|
| 36 |
+
HTTP Request: POST https://generati
|
| 37 |
+
velanguage.googleapis.com/v1beta/op
|
| 38 |
+
enai/chat/completions "HTTP/1.1
|
| 39 |
+
404 Not Found"
|
| 40 |
+
Inference Error: Error code: 404 -
|
| 41 |
+
[{'error': {'code': 404,
|
| 42 |
+
'message':
|
| 43 |
+
'models/gemini-1.5-flash is not
|
| 44 |
+
found for API version v1main, or
|
| 45 |
+
is not supported for
|
| 46 |
+
generateContent. Call ListModels
|
| 47 |
+
to see the list of available
|
| 48 |
+
models and their supported
|
| 49 |
+
methods.', 'status': 'NOT_FOUND'}}]
|
| 50 |
+
HTTP Request: POST https://generati
|
| 51 |
+
velanguage.googleapis.com/v1beta/op
|
| 52 |
+
enai/chat/completions "HTTP/1.1
|
| 53 |
+
404 Not Found"
|
| 54 |
+
Inference Error: Error code: 404 -
|
| 55 |
+
[{'error': {'code': 404,
|
| 56 |
+
'message':
|
| 57 |
+
'models/gemini-1.5-flash is not
|
| 58 |
+
found for API version v1main, or
|
| 59 |
+
is not supported for
|
| 60 |
+
generateContent. Call ListModels
|
| 61 |
+
to see the list of available
|
| 62 |
+
models and their supported
|
| 63 |
+
methods.', 'status': 'NOT_FOUND'}}]
|
stdout.log
ADDED
|
Binary file (1.05 kB). View file
|
|
|
stdout_utf8.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
task=noise_purge env=contextprune model=gemini-1.5-flash step=1 action=[1, 1] reward=0.70 done=true error=null success=false steps=1 score=0.50 rewards=0.70
|
| 2 |
+
task=dedupe_arena env=contextprune model=gemini-1.5-flash step=1 action=[1, 1, 1] reward=0.70 done=true error=null success=false steps=1 score=0.50 rewards=0.70
|
| 3 |
+
task=signal_extract env=contextprune model=gemini-1.5-flash step=1 action=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] reward=0.70 done=true error=null success=false steps=1 score=0.10 rewards=0.70
|