import os
import re
import json
import time
import asyncio
from typing import List, Tuple

import gradio as gr
from google import genai
from dotenv import load_dotenv

from context_pruning_env.utils import count_tokens

# Load API keys from .env
load_dotenv()

# --- Configuration ---
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
client = genai.Client(api_key=API_KEY)

# Fallback sequence for 2026 availability & quota limits.
# Tried in order; each model gets its own retry budget before falling through.
MODEL_SEQUENCE = [
    os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
    "gemini-2.5-flash",
    "gemini-3.1-flash-live-preview",
    "gemini-1.5-flash-8b",
]


def call_gemini_with_retry(prompt: str) -> str:
    """Call Gemini with exponential backoff and model fallback.

    Walks MODEL_SEQUENCE in order. For each model, up to 2 attempts are
    made; quota/429 errors trigger an exponential sleep (3s, then 6s),
    while any other error abandons that model immediately and moves on
    to the next one in the sequence.

    Args:
        prompt: Full prompt text sent as the request contents.

    Returns:
        The model's text response, or an "ERROR: ..." string when the
        API key is missing or every model failed. Callers detect failure
        by substring-matching "ERROR" in the return value.
    """
    if not API_KEY:
        return "ERROR: API Key not found."

    for model_name in MODEL_SEQUENCE:
        retries = 2
        backoff = 3
        for _attempt in range(retries):
            try:
                response = client.models.generate_content(
                    model=model_name,
                    config={
                        'temperature': 0.1,
                        'top_p': 0.95,
                        'max_output_tokens': 512,
                    },
                    contents=prompt,
                )
                # Empty/None responses fall through and count as a retry.
                if response and response.text:
                    return response.text
            except Exception as e:
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str:
                    # Rate-limited: back off exponentially, then retry
                    # the same model.
                    time.sleep(backoff)
                    backoff *= 2
                else:
                    break  # Non-quota error: try next model

    return "ERROR: All models hit quota or failed."


def chunk_text(text: str, max_chunks: int = 20) -> List[str]:
    """Split text into sentence-level chunks.

    First splits on blank lines (paragraphs), then splits each paragraph
    into sentences on terminal punctuation or single newlines. The result
    is truncated to at most ``max_chunks`` entries.

    Args:
        text: Raw input text.
        max_chunks: Upper bound on the number of chunks returned.

    Returns:
        A list of non-empty, stripped sentence strings (possibly empty
        when ``text`` is blank).
    """
    initial_chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
    final_chunks: List[str] = []
    for chunk in initial_chunks:
        sentences = [
            s.strip()
            for s in re.split(r'(?<=[.!?])\s+|\n', chunk)
            if s.strip()
        ]
        final_chunks.extend(sentences)
    return final_chunks[:max_chunks]


async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
    """Prune ``raw_text`` down to the chunks relevant to ``query``.

    Asks the model for a binary keep/drop mask over the chunks, rebuilds
    the kept text, computes token-reduction metrics, then runs a second
    "grounding" call that replies PASS/FAIL on whether the pruned context
    still answers the query.

    Args:
        query: The user's question.
        raw_text: The full context to be pruned.

    Returns:
        Tuple of (optimized_text, metrics_dict, ground_result), where
        ground_result is the model's PASS/FAIL verdict (or "FAIL" / ""
        on early error paths).
    """
    if not query or not raw_text:
        return "Please provide both.", {}, ""

    chunks = chunk_text(raw_text)

    selection_prompt = (
        f"Query: {query}\n\n"
        "TASK: AGGRESSIVE CONTEXT OPTIMIZATION. "
        "Goal: TOKEN REDUCTION. Prune noise and keep ONLY essential info.\n"
        f"OUTPUT: Output EXACTLY {len(chunks)} binary integers [0 or 1] as a JSON list.\n\n"
        "Chunks:\n"
    )
    for i, c in enumerate(chunks):
        selection_prompt += f"Chunk {i}: {c}\n"

    # Blocking SDK call: run in the default executor so the event loop
    # stays responsive. get_running_loop() is the non-deprecated form
    # inside a coroutine (get_event_loop() is deprecated here).
    loop = asyncio.get_running_loop()
    raw_response = await loop.run_in_executor(
        None, call_gemini_with_retry, selection_prompt
    )

    if "ERROR" in raw_response:
        return raw_response, {}, "FAIL"

    indices: List[int] = []
    try:
        # Extract the first JSON-looking list of digits from the reply.
        match = re.search(r"\[([\d\s,]+)\]", raw_response)
        if match:
            mask = json.loads(match.group(0))
            # Pad/truncate so the mask always matches the chunk count.
            mask = (mask + [0] * len(chunks))[:len(chunks)]
            indices = [i for i, m in enumerate(mask) if int(m) == 1]
    except (ValueError, TypeError, json.JSONDecodeError):
        # Malformed mask (bad JSON, non-numeric entries): treat as
        # "nothing selected" rather than crashing the request.
        indices = []

    if not indices:
        optimized_text = "No matches found or optimization too aggressive."
    else:
        optimized_text = " ".join([chunks[i] for i in sorted(indices)])

    orig_tokens = count_tokens(raw_text)
    final_tokens = count_tokens(optimized_text)
    # Guard the zero-token edge case to avoid ZeroDivisionError.
    reduction = (
        (orig_tokens - final_tokens) / orig_tokens * 100 if orig_tokens > 0 else 0
    )

    metrics = {
        "Original Tokens": f"{orig_tokens}",
        "Final Tokens": f"{final_tokens}",
        "Reduction Score": f"{reduction:.1f}%",
    }

    ground_prompt = (
        f"Question: {query}\nContext: {optimized_text}\n\n"
        "Task: Response with 'PASS' if info present, else 'FAIL'."
    )
    ground_result = await loop.run_in_executor(
        None, call_gemini_with_retry, ground_prompt
    )

    return optimized_text, metrics, ground_result


# --- Gradio UI with Premium Styling ---
def get_status_html(result: str):
    """Render a PASS/FAIL grounding verdict as an HTML status badge.

    NOTE(review): the original source is truncated mid-string inside this
    function; the markup below is a minimal reconstruction of the visible
    intent (green badge when "PASS" appears in the result, red otherwise)
    — confirm against the original UI before shipping.
    """
    if "PASS" in result.upper():
        return '<div style="color: #16a34a; font-weight: bold;">&#9989; PASS</div>'
    return '<div style="color: #dc2626; font-weight: bold;">&#10060; FAIL</div>'