File size: 4,085 Bytes
0b89610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import json
import re
import time
from google import genai
from dotenv import load_dotenv
from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import ContextAction

# Load environment variables (e.g. GEMINI_API_KEY, MODEL_NAME) from a local .env file.
load_dotenv()

# --- SDK CONFIGURATION ---
# Either variable name is accepted; GEMINI_API_KEY wins when it is set and non-empty.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
client = genai.Client(api_key=API_KEY)

# Models tried in order by call_with_retry when a model is rate-limited, missing,
# or errors out. MODEL_NAME (default "gemini-2.0-flash") is tried first.
# dict.fromkeys drops duplicates while keeping first-occurrence order, so a
# MODEL_NAME override that matches a later entry is not retried twice; empty
# names (MODEL_NAME="") are skipped.
# NOTE(review): "-live-preview" models may only be served through the Live API
# and could reject generate_content — confirm this entry is usable.
MODEL_SEQUENCE = [
    name
    for name in dict.fromkeys([
        os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
        "gemini-2.5-flash",
        "gemini-3.1-flash-live-preview",
        "gemini-1.5-flash-8b",
    ])
    if name
]

def call_with_retry(prompt, *, models=None, retries=3, initial_backoff=5,
                    api_client=None, sleep=time.sleep):
    """Call Gemini with exponential backoff on quota errors and model fallback.

    Models are tried in order. Each model gets up to ``retries`` attempts:

    * Quota / rate-limit errors (messages containing "429", "quota" or
      "resource") wait ``initial_backoff`` seconds and then retry the same
      model. The wait doubles after each retry. No wait happens after the
      final attempt; the function moves straight on to the next model.
    * "404" / "not found" errors, and any other exception, skip to the next
      model immediately.
    * An empty response is retried immediately without waiting.

    Args:
        prompt: Contents passed to ``generate_content``.
        models: Model names to try in order. Defaults to ``MODEL_SEQUENCE``.
        retries: Maximum attempts per model.
        initial_backoff: First quota back-off delay in seconds. It doubles on
            each further retry of the same model.
        api_client: Object exposing ``models.generate_content``. Defaults to
            the module-level ``client``.
        sleep: Callable used to wait between retries (injectable for tests).

    Returns:
        ``(text, model_name)`` for the first non-empty response, or
        ``(None, None)`` if every model and attempt failed.
    """
    if models is None:
        models = MODEL_SEQUENCE
    if api_client is None:
        api_client = client

    for model_name in models:
        backoff = initial_backoff

        for attempt in range(retries):
            try:
                print(f"DEBUG: Attempting {model_name} (Attempt {attempt+1}/{retries})...")
                response = api_client.models.generate_content(
                    model=model_name,
                    config={
                        'temperature': 0.1,
                        'top_p': 0.95,
                        'max_output_tokens': 512,
                    },
                    contents=prompt,
                )
                if response and response.text:
                    return response.text, model_name
                # No usable text (e.g. empty candidate); retry the same model.
                print(f"DEBUG: Empty response from {model_name}.")

            except Exception as e:
                # Errors are classified by message text: this code does not
                # depend on a specific SDK exception type.
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str or "resource" in err_str:
                    if attempt + 1 >= retries:
                        # Sleeping now would only delay the next model.
                        print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Retries exhausted; falling back to next model.")
                        break
                    print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Retrying in {backoff}s...")
                    sleep(backoff)
                    backoff *= 2
                elif "404" in err_str or "not found" in err_str:
                    print(f"DEBUG: MODEL {model_name} NOT FOUND. Falling back to next model.")
                    break  # Try next model in sequence
                else:
                    print(f"LOUD ERROR: {e}")
                    # Non-retryable for this model; still try the next model.
                    break

    return None, None

# A bracketed list of integers as the model may emit it, e.g. "[1, 0, 1]" or "[1 0 1]".
_MASK_LIST_RE = re.compile(r"\[([\d\s,]+)\]")
_INT_RE = re.compile(r"\d+")


def _build_prompt(question, chunks):
    """Return the pruning prompt: the query, instructions, then "[i]: chunk" lines."""
    n_chunks = len(chunks)
    header = (
        f"Query: {question}\n\n"
        f"TASKS: Prune the following {n_chunks} chunks. Output EXACTLY {n_chunks} "
        "binary integers [0 or 1] as a JSON list (1 = keep the chunk, 0 = prune it).\n"
        "Chunks:\n"
    )
    return header + "".join(f"[{i}]: {chunk}\n" for i, chunk in enumerate(chunks))


def _parse_mask(raw_text, n_chunks):
    """Extract a 0/1 mask of exactly ``n_chunks`` entries from model output.

    * When several bracketed integer lists appear, the one with the most
      entries wins (ties go to the first). This stops an echoed chunk label
      such as "[0]" from shadowing the real answer.
    * Every entry is coerced to binary: 0 stays 0 (prune), any other integer
      becomes 1 (keep).
    * The result is padded with 1s or truncated to ``n_chunks``. It is all
      ones when nothing parses.
    """
    candidates = [_INT_RE.findall(body) for body in _MASK_LIST_RE.findall(raw_text)]
    best = max(candidates, key=len, default=[])
    mask = [0 if int(tok) == 0 else 1 for tok in best]
    return (mask + [1] * n_chunks)[:n_chunks]


def run_inference():
    """Run each pruning task once against ContextPruningEnv and print scores.

    For every task this function:

    1. Resets the env.
    2. Asks the model (via ``call_with_retry``) for a keep/prune mask over the
       observation's chunks.
    3. Steps the env with that mask.
    4. Prints ``[START]`` / ``[STEP]`` / ``[END]`` log lines.

    When every model fails, an all-ones (keep everything) identity mask is
    used instead. The average ``eval_score`` across tasks is printed at the end.
    """
    env = ContextPruningEnv()
    tasks = ["noise_purge", "dedupe_arena", "signal_extract"]

    total_score = 0.0
    for task in tasks:
        print(f"\n--- Starting Task: {task} ---")
        print(f"[START] task={task}")
        obs = env.reset(task_name=task)
        n_chunks = len(obs.chunks)

        raw_text, used_model = call_with_retry(_build_prompt(obs.question, obs.chunks))

        if raw_text:
            print(f"DEBUG: Used Model: {used_model} | RAW RESP: {raw_text}")
            mask = _parse_mask(raw_text, n_chunks)
        else:
            print(f"DEBUG: ALL MODELS FAILED for {task}. Using identity mask.")
            mask = [1] * n_chunks

        obs = env.step(ContextAction(mask=mask))

        # Guard against a missing/None metadata dict, score or reward so the
        # :.2f formatting below cannot raise.
        metadata = getattr(obs, "metadata", None) or {}
        score = metadata.get("eval_score")
        if score is None:
            score = 0.0
        reward = getattr(obs, "reward", None)
        if reward is None:
            reward = 0.0

        print(f"[STEP] reward={reward:.2f} mask={mask}")
        print(f"[END] task={task} score={score:.2f} success={str(score > 0.5).lower()}")
        total_score += score

    print(f"\nINFO:--- ALL TASKS COMPLETE. FINAL AVG SCORE: {total_score / len(tasks):.2f} ---")

if __name__ == "__main__":
    # Script entry point: run every task once when executed directly (not on import).
    run_inference()