import os
import re
import json
import time
import asyncio
from typing import List, Tuple

import gradio as gr
from google import genai
from dotenv import load_dotenv

from context_pruning_env.utils import count_tokens

# Load API keys from .env
load_dotenv()

# --- Configuration ---
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
# Only construct the client when a key is present; genai.Client raises otherwise.
client = genai.Client(api_key=API_KEY) if API_KEY else None

# Fallback sequence for 2026 availability & quota limits
MODEL_SEQUENCE = [
    os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
    "gemini-2.5-flash",
    "gemini-3.1-flash-live-preview",
    "gemini-1.5-flash-8b",
]


def call_gemini_with_retry(prompt: str) -> str:
    """Call Gemini with exponential backoff on quota errors and model fallback."""
    if not API_KEY:
        return "ERROR: API Key not found."
    for model_name in MODEL_SEQUENCE:
        retries = 2
        backoff = 3
        for attempt in range(retries):
            try:
                response = client.models.generate_content(
                    model=model_name,
                    config={
                        "temperature": 0.1,
                        "top_p": 0.95,
                        "max_output_tokens": 512,
                    },
                    contents=prompt,
                )
                if response and response.text:
                    return response.text
            except Exception as e:
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str:
                    # Rate-limited: back off, skipping the sleep on the final attempt.
                    if attempt < retries - 1:
                        time.sleep(backoff)
                        backoff *= 2
                else:
                    break  # Non-quota error: fall through to the next model.
    return "ERROR: All models hit quota or failed."


def chunk_text(text: str, max_chunks: int = 20) -> List[str]:
    """Split text into paragraphs, then sentences, capped at max_chunks."""
    initial_chunks = [c.strip() for c in re.split(r"\n\s*\n", text) if c.strip()]
    final_chunks = []
    for chunk in initial_chunks:
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+|\n", chunk) if s.strip()]
        final_chunks.extend(sentences)
    return final_chunks[:max_chunks]


async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
    """Pruning logic with robust retry wrapper."""
    if not query or not raw_text:
        return "Please provide both a query and document text.", {}, ""

    chunks = chunk_text(raw_text)
    selection_prompt = (
        f"Query: {query}\n\n"
        "TASK: AGGRESSIVE CONTEXT OPTIMIZATION. "
        "Goal: TOKEN REDUCTION. Prune noise and keep ONLY essential info.\n"
        f"OUTPUT: Output EXACTLY {len(chunks)} binary integers [0 or 1] as a JSON list.\n\n"
        "Chunks:\n"
    )
    for i, c in enumerate(chunks):
        selection_prompt += f"Chunk {i}: {c}\n"

    # Run the blocking SDK call off the event loop.
    loop = asyncio.get_running_loop()
    raw_response = await loop.run_in_executor(None, call_gemini_with_retry, selection_prompt)
    if raw_response.startswith("ERROR"):
        return raw_response, {}, "FAIL"

    indices = []
    try:
        match = re.search(r"\[([\d\s,]+)\]", raw_response)
        if match:
            mask = json.loads(match.group(0))
            # Pad or truncate the mask so it always matches the chunk count.
            mask = (mask + [0] * len(chunks))[: len(chunks)]
            indices = [i for i, m in enumerate(mask) if int(m) == 1]
    except (ValueError, TypeError):
        indices = []

    if not indices:
        optimized_text = "No matches found or optimization too aggressive."
    else:
        optimized_text = " ".join(chunks[i] for i in sorted(indices))

    orig_tokens = count_tokens(raw_text)
    final_tokens = count_tokens(optimized_text)
    reduction = ((orig_tokens - final_tokens) / orig_tokens * 100) if orig_tokens > 0 else 0
    metrics = {
        "Original Tokens": f"{orig_tokens}",
        "Final Tokens": f"{final_tokens}",
        "Reduction Score": f"{reduction:.1f}%",
    }

    # Groundedness check: can the pruned context still answer the query?
    ground_prompt = (
        f"Question: {query}\nContext: {optimized_text}\n\n"
        "Task: Respond with 'PASS' if the info is present, else 'FAIL'."
    )
    ground_result = await loop.run_in_executor(None, call_gemini_with_retry, ground_prompt)
    return optimized_text, metrics, ground_result


# --- Gradio UI with Premium Styling ---
def get_status_html(result: str) -> str:
    # NOTE: the original <div> markup was stripped in transit; this is a minimal
    # reconstruction using colors consistent with the dark theme below.
    if "PASS" in result.upper():
        return '<div style="color: #22c55e; font-weight: bold;">🚀 GROUNDEDNESS SUCCESS</div>'
    return '<div style="color: #ef4444; font-weight: bold;">⚠️ GROUNDEDNESS FAILURE</div>'


CSS = """
body { background-color: #0f172a; color: white; }
.gradio-container {
    border-radius: 20px !important;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5) !important;
}
#title { text-align: center; font-size: 2.5em; margin-bottom: 20px; color: #38bdf8; }
"""

# `theme` and `css` are gr.Blocks() arguments; demo.launch() does not accept them.
with gr.Blocks(
    title="ContextPrune",
    theme=gr.themes.Default(primary_hue="blue", neutral_hue="slate"),
    css=CSS,
) as demo:
    gr.Markdown("# 🧠 ContextPrune AI: Quota-Resilient Context Compression", elem_id="title")
    with gr.Tabs():
        with gr.TabItem("Optimizer"):
            with gr.Row():
                with gr.Column(scale=2):
                    query_in = gr.Textbox(
                        label="🔍 User Query",
                        placeholder="What are the key technical findings?",
                        lines=2,
                    )
                    context_in = gr.Textbox(
                        label="📄 Noisy Document Content",
                        placeholder="Paste large blocks of text here...",
                        lines=15,
                    )
                    btn = gr.Button("🔥 Prune Context Now", variant="primary", size="lg")
                with gr.Column(scale=1):
                    # gr.JSON renders the string-valued metrics dict;
                    # gr.Label expects numeric confidence scores.
                    metrics_lbl = gr.JSON(label="Optimization Efficiency")
                    status = gr.HTML()
            out = gr.Textbox(label="✨ Optimized Context (Ready for LLM)", interactive=False, lines=15)

            async def run_ui(q, c):
                txt, m, g = await prune_context(q, c)
                return txt, get_status_html(g), m

            btn.click(run_ui, [query_in, context_in], [out, status, metrics_lbl])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)