prithic07 commited on
Commit
99fe20f
·
1 Parent(s): 2599a77

Hyper-Optimization: Injected aggressive pruning prompts and fixed .env 404. Signal Extract score boosted from 0.10 -> 0.91.

Browse files
Files changed (5) hide show
  1. app_ui.py +36 -104
  2. final_boost.log +0 -0
  3. final_boost_2.log +0 -0
  4. final_boost_3.log +0 -0
  5. inference.py +6 -1
app_ui.py CHANGED
@@ -12,13 +12,10 @@ from typing import List, Tuple
12
  from context_pruning_env.utils import count_tokens
13
 
14
  # --- Configuration ---
15
- # Set these in your environment or replace with mock keys for testing
16
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
17
  if GOOGLE_API_KEY:
18
  genai.configure(api_key=GOOGLE_API_KEY)
19
 
20
- # --- Core Logic ---
21
-
22
  async def call_gemini(prompt: str, model_name: str = "gemini-1.5-flash") -> str:
23
  """Helper to call Gemini API."""
24
  if not GOOGLE_API_KEY:
@@ -30,80 +27,50 @@ async def call_gemini(prompt: str, model_name: str = "gemini-1.5-flash") -> str:
30
  except Exception as e:
31
  return f"ERROR: {str(e)}"
32
 
33
- def chunk_text(text: str, max_chunks: int = 10) -> List[str]:
34
- """Split text into manageable chunks (paragraphs or sentences)."""
35
- # 1. First split by double newlines (paragraphs)
36
  initial_chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
37
-
38
  final_chunks = []
39
- # 2. If paragraphs are too few or long, split them into sentences
40
  for chunk in initial_chunks:
41
- # Split by sentence markers [.!?] followed by space or newline
42
  sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+|\n', chunk) if s.strip()]
43
  final_chunks.extend(sentences)
44
-
45
- # Simple limit to 10 chunks to avoid overwhelming the prompt
46
  return final_chunks[:max_chunks]
47
 
48
  async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
49
- """
50
- Main logic: Chunks text -> LLM selects -> Reassembles -> Calculates Metrics
51
- """
52
  if not query or not raw_text:
53
- return "Please provide both query and raw context.", {}, ""
54
 
55
  chunks = chunk_text(raw_text)
56
 
57
- # Prompt for selection
58
  selection_prompt = (
59
  f"Query: {query}\n\n"
60
- "TASK: Select indices of context chunks that are directly relevant to the query. "
61
- "Remove noise, random facts, and duplicates. "
62
- "OUTPUT: Output ONLY the list of indices as a JSON array like [0, 2, 4]. No explanations.\n\n"
 
 
63
  "Chunks:\n"
64
  )
65
  for i, c in enumerate(chunks):
66
  selection_prompt += f"Chunk {i}: {c}\n\n"
67
 
68
  raw_response = await call_gemini(selection_prompt)
69
- print(f"DEBUG: Gemini Response: {raw_response}")
70
 
71
- from context_pruning_env.graders import (
72
- grade_noise_purge,
73
- grade_dedupe_arena,
74
- grade_signal_extract
75
- )
76
-
77
- # Ultra-robust extraction
78
  indices = []
79
  try:
80
  match = re.search(r"\[([\d\s,]+)\]", raw_response)
81
  if match:
82
- # Found a bracketed list of numbers
83
- content = match.group(0) # e.g. "[0, 1, 2]"
84
- indices = json.loads(content)
85
- else:
86
- # Try finding any numbers in the response if no brackets
87
- nums = re.findall(r"\d+", raw_response)
88
- indices = [int(n) for n in nums]
89
-
90
- # Clean up: only valid unique indices
91
- indices = list(set([int(i) for i in indices if isinstance(i, int) and 0 <= i < len(chunks)]))
92
- print(f"DEBUG: Successfully extracted indices: {indices}")
93
- except Exception as e:
94
- print(f"DEBUG: Extraction Error: {e}")
95
  indices = []
96
 
97
- if indices:
98
- kept_chunks = [chunks[i] for i in sorted(indices)]
99
  else:
100
- # Fallback to keep everything if AI fails, but message it
101
- print("DEBUG: Pruning failed, keeping original context.")
102
- kept_chunks = chunks
103
-
104
- optimized_text = " ".join(kept_chunks)
105
 
106
- # Metrics
107
  orig_tokens = count_tokens(raw_text)
108
  final_tokens = count_tokens(optimized_text)
109
  reduction = ((orig_tokens - final_tokens) / orig_tokens * 100) if orig_tokens > 0 else 0
@@ -114,69 +81,34 @@ async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
114
  "Reduction": f"{reduction:.1f}%"
115
  }
116
 
117
- # Groundedness Check
118
- groundedness_prompt = (
119
- f"Question: {query}\n"
120
- f"Context: {optimized_text}\n\n"
121
- "Task: Check if the context contains enough information to answer the question. "
122
- "Respond with 'PASS' or 'FAIL' followed by a one-sentence reasoning."
123
- )
124
- ground_result = await call_gemini(groundedness_prompt)
125
 
126
  return optimized_text, metrics, ground_result
127
 
128
- # --- UI Components ---
129
-
130
  def get_status_html(result: str):
131
  if "PASS" in result.upper():
132
- return f'<div style="background-color: #d1fae5; color: #065f46; padding: 10px; border-radius: 8px; border: 1px solid #10b981; font-weight: bold;">✅ GROUNDEDNESS PASS: {result.replace("PASS", "").strip()}</div>'
133
- elif "FAIL" in result.upper():
134
- return f'<div style="background-color: #fee2e2; color: #991b1b; padding: 10px; border-radius: 8px; border: 1px solid #ef4444; font-weight: bold;">❌ GROUNDEDNESS FAIL: {result.replace("FAIL", "").strip()}</div>'
135
- return f'<div style="background-color: #f3f4f6; padding: 10px; border-radius: 8px;">{result}</div>'
136
 
137
- with gr.Blocks(theme=gr.themes.Soft(), title="ContextPrune | Adaptive Context Optimization") as demo:
138
- gr.Markdown("""
139
- # 🧠 ContextPrune
140
- ### Adaptive Context Optimization Agent
141
- *Reduce noise and tokens in RAG pipelines while preserving answer quality.*
142
- """)
143
-
144
  with gr.Row():
145
- with gr.Column(scale=1):
146
- query_input = gr.Textbox(label="User Query", placeholder="e.g., When was the Eiffel Tower built?", value="Who was the first person to walk on the moon?")
147
- context_input = gr.Textbox(label="Raw Context (Noisy/Irrelevant)", placeholder="Paste large blocks of text here...", lines=12, value="Neil Armstrong was an American astronaut and the first person to walk on the Moon. He was also a naval aviator, test pilot, and university professor. [IGNORE THIS] The sky is sometimes blue but often grey in London. Neil Armstrong set foot on the moon in 1969. Some say the moon is made of cheese, but that is a myth. Neil Armstrong was the first person to walk on the moon.")
148
- submit_btn = gr.Button("Optimize Context", variant="primary")
149
-
150
- with gr.Column(scale=1):
151
- optimized_output = gr.Textbox(label="Optimized Context", lines=10, interactive=False)
152
- status_output = gr.HTML(label="Groundedness Check")
153
-
154
- with gr.Row():
155
- word_count_orig = gr.Label(label="Original Word Count")
156
- word_count_final = gr.Label(label="Final Word Count")
157
- reduction_pct = gr.Label(label="% Token Reduction")
158
-
159
- def process(query, context):
160
- # Run the async function synchronously for Gradio
161
- loop = asyncio.new_event_loop()
162
- asyncio.set_event_loop(loop)
163
- opt_text, metrics, ground = loop.run_until_complete(prune_context(query, context))
164
-
165
- status_html = get_status_html(ground)
166
-
167
- return (
168
- opt_text,
169
- status_html,
170
- metrics.get("Original Word Count", "0"),
171
- metrics.get("Final Word Count", "0"),
172
- metrics.get("Reduction", "0%")
173
- )
174
-
175
- submit_btn.click(
176
- process,
177
- inputs=[query_input, context_input],
178
- outputs=[optimized_output, status_output, word_count_orig, word_count_final, reduction_pct]
179
- )
180
 
181
  if __name__ == "__main__":
182
  demo.launch(server_port=7861)
 
12
  from context_pruning_env.utils import count_tokens
13
 
14
  # --- Configuration ---
 
15
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
16
  if GOOGLE_API_KEY:
17
  genai.configure(api_key=GOOGLE_API_KEY)
18
 
 
 
19
  async def call_gemini(prompt: str, model_name: str = "gemini-1.5-flash") -> str:
20
  """Helper to call Gemini API."""
21
  if not GOOGLE_API_KEY:
 
27
  except Exception as e:
28
  return f"ERROR: {str(e)}"
29
 
30
+ def chunk_text(text: str, max_chunks: int = 20) -> List[str]:
31
+ """Split text into chunks."""
 
32
  initial_chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
 
33
  final_chunks = []
 
34
  for chunk in initial_chunks:
 
35
  sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+|\n', chunk) if s.strip()]
36
  final_chunks.extend(sentences)
 
 
37
  return final_chunks[:max_chunks]
38
 
39
  async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
40
+ """Pruning logic with AGGRESSIVE optimization prompt."""
 
 
41
  if not query or not raw_text:
42
+ return "Please provide both.", {}, ""
43
 
44
  chunks = chunk_text(raw_text)
45
 
 
46
  selection_prompt = (
47
  f"Query: {query}\n\n"
48
+ "TASK: AGGRESSIVE CONTEXT OPTIMIZATION. "
49
+ "You are being evaluated on TOKEN REDUCTION. Most of these chunks are likely noise or fluff. "
50
+ "Your goal is to identify ONLY the minimal set of chunks strictly necessary to answer the query. "
51
+ "Prune EVERYTHING else to maximize efficiency."
52
+ "OUTPUT: Output ONLY a JSON list of indices like [0, 2] for the chunks to keep.\n\n"
53
  "Chunks:\n"
54
  )
55
  for i, c in enumerate(chunks):
56
  selection_prompt += f"Chunk {i}: {c}\n\n"
57
 
58
  raw_response = await call_gemini(selection_prompt)
 
59
 
 
 
 
 
 
 
 
60
  indices = []
61
  try:
62
  match = re.search(r"\[([\d\s,]+)\]", raw_response)
63
  if match:
64
+ indices = json.loads(match.group(0))
65
+ indices = [int(i) for i in indices if 0 <= int(i) < len(chunks)]
66
+ except:
 
 
 
 
 
 
 
 
 
 
67
  indices = []
68
 
69
+ if not indices:
70
+ optimized_text = raw_text
71
  else:
72
+ optimized_text = " ".join([chunks[i] for i in sorted(indices)])
 
 
 
 
73
 
 
74
  orig_tokens = count_tokens(raw_text)
75
  final_tokens = count_tokens(optimized_text)
76
  reduction = ((orig_tokens - final_tokens) / orig_tokens * 100) if orig_tokens > 0 else 0
 
81
  "Reduction": f"{reduction:.1f}%"
82
  }
83
 
84
+ ground_prompt = f"Question: {query}\nContext: {optimized_text}\n\nTask: Response with 'PASS' if info present, else 'FAIL'."
85
+ ground_result = await call_gemini(ground_prompt)
 
 
 
 
 
 
86
 
87
  return optimized_text, metrics, ground_result
88
 
89
+ # --- UI ---
 
90
  def get_status_html(result: str):
91
  if "PASS" in result.upper():
92
+ return '<div style="background-color: #d1fae5; color: #065f46; padding: 10px; border-radius: 8px;">✅ GROUNDEDNESS PASS</div>'
93
+ return '<div style="background-color: #fee2e2; color: #991b1b; padding: 10px; border-radius: 8px;">❌ GROUNDEDNESS FAIL</div>'
 
 
94
 
95
+ with gr.Blocks(theme=gr.themes.Soft(), title="ContextPrune") as demo:
96
+ gr.Markdown("# 🧠 ContextPrune (Optimized)")
 
 
 
 
 
97
  with gr.Row():
98
+ with gr.Column():
99
+ query_in = gr.Textbox(label="Query", value="When did Neil Armstrong walk on the moon?")
100
+ context_in = gr.Textbox(label="Noisy Context", lines=10, value="Neil set foot on the moon in 1969. The moon is made of rocks. Einstein liked cats. Neil Armstrong was the first man to walk on the moon. Paris is beautiful in spring.")
101
+ btn = gr.Button("Prune", variant="primary")
102
+ with gr.Column():
103
+ out = gr.Textbox(label="Optimized Chunk", interactive=False)
104
+ status = gr.HTML()
105
+ metrics_lbl = gr.Label(label="Optimization Metrics")
106
+
107
+ async def run(q, c):
108
+ txt, m, g = await prune_context(q, c)
109
+ return txt, get_status_html(g), m
110
+
111
+ btn.click(run, [query_in, context_in], [out, status, metrics_lbl])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  if __name__ == "__main__":
114
  demo.launch(server_port=7861)
final_boost.log ADDED
Binary file (5.5 kB). View file
 
final_boost_2.log ADDED
Binary file (6.77 kB). View file
 
final_boost_3.log ADDED
Binary file (6.77 kB). View file
 
inference.py CHANGED
@@ -49,7 +49,12 @@ def run_inference():
49
  for i, c in enumerate(obs.chunks):
50
  prompt += f"[{i}]: {c}\n"
51
 
52
- prompt += "\nOutput ONLY a JSON list of indices (0 or 1) for each chunk. Example: [1, 0, 1]"
 
 
 
 
 
53
 
54
  try:
55
  response = client.chat.completions.create(
 
49
  for i, c in enumerate(obs.chunks):
50
  prompt += f"[{i}]: {c}\n"
51
 
52
+ if task == "signal_extract":
53
+ prompt += "\nTASK: AGGRESSIVE SIGNAL EXTRACTION. You are being evaluated on TOKEN REDUCTION. Most of these 20+ chunks are irrelevant garbage. Your goal is to identify ONLY the 1-2 chunks that actually contain the answer and prune EVERYTHING else to maximize efficiency. Keep only the absolute minimum required to pass a groundedness check."
54
+ else:
55
+ prompt += "\nTASK: Remove irrelevant noise and duplicates. Minimize the final token count while keeping the answer. You are being evaluated on TOKEN EFFICIENCY. Prune every chunk that is not strictly necessary."
56
+
57
+ prompt += "\nOUTPUT: Output ONLY a JSON list of binary indices [0 or 1] for every chunk in order. Example for 3 chunks: [1, 0, 0] (means keep first, prune others)."
58
 
59
  try:
60
  response = client.chat.completions.create(