gridworld-env / OpenEnv /examples /repl_oolong_simple.py
Abhilasha Kakoty
Initial deploy
7078f4d
#!/usr/bin/env python3
"""
Simple REPL + Oolong example with recursive LLM calls (RLM paradigm).
Uses LocalRLMRunner, which handles both the outer loop (code generation)
and the inner calls (llm_query/rlm_query) with a single chat function.
Usage (a Hugging Face token is read from the HF_TOKEN environment variable):
python examples/repl_oolong_simple.py
"""
from __future__ import annotations
import os
from datasets import load_dataset
from huggingface_hub import InferenceClient
from repl_env import LocalRLMRunner
from repl_env.prompts import RLM_SYSTEM_PROMPT_QWEN
# ---------------------------------------------------------------------------
# Run configuration: edit these values to try another model or example.
# ---------------------------------------------------------------------------

# Model id handed to the Hugging Face InferenceClient for every chat call.
MODEL_NAME = "Qwen/Qwen3-Coder-480B-A35B-Instruct"

# Which slice of the oolong-real dataset to evaluate, and which row of it.
DATASET_SUBSET, DATASET_SPLIT = "toy_dnd", "validation"
EXAMPLE_INDEX = 0

# Upper bound on outer RLM iterations (the paper uses 30).
MAX_ITERATIONS = 30

# ---------------------------------------------------------------------------

# Hugging Face access token from the environment; None when it is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
def create_chat_fn(*, max_tokens: int = 1024):
    """Create the chat function with Qwen3-Coder recommended params.

    The returned callable is used both for the outer RLM loop and for the
    inner llm_query/rlm_query calls, so it always returns a ``str``.

    Args:
        max_tokens: Completion-length cap for each call. This is a local choice,
            not one of the model-card sampling recommendations. The default of
            1024 preserves the original behaviour; raise it if generated code
            gets truncated.

    Returns:
        A function ``chat_fn(messages, model=None) -> str`` backed by a
        Hugging Face ``InferenceClient`` with a 300-second timeout.
    """
    client = InferenceClient(model=MODEL_NAME, token=HF_TOKEN, timeout=300)

    def chat_fn(messages: list[dict], model: str | None = None) -> str:
        """Send *messages* to *model* (default MODEL_NAME) and return the reply text."""
        response = client.chat.completions.create(
            model=model or MODEL_NAME,
            messages=messages,
            max_tokens=max_tokens,
            # Qwen3-Coder-480B sampling params (from model card)
            temperature=0.7,
            top_p=0.8,
            # Passed via extra_body so they are sent in the request payload as-is.
            extra_body={
                "top_k": 20,
                "repetition_penalty": 1.05,
            },
        )
        # The API types message content as optional (it can be None, e.g. for an
        # empty completion). Normalise it so callers always get the declared str.
        return response.choices[0].message.content or ""

    return chat_fn
def _normalize_answer(answer: object) -> str:
    """Return *answer* as a stripped, lower-cased string for lenient comparison."""
    return str(answer).strip().lower()


def main():
    """Run one Oolong example through LocalRLMRunner and report the outcome.

    Loads row EXAMPLE_INDEX of the configured oolong-real subset/split and runs
    the RLM loop on that example's context and question. It then prints the
    expected and obtained answers and reports CORRECT/INCORRECT using an exact
    match that ignores case and surrounding whitespace.
    """
    print("=" * 60)
    print("REPL + Oolong with Recursive LLM Calls (RLM)")
    print("=" * 60)

    # Load dataset
    print(f"\nLoading dataset example {EXAMPLE_INDEX}...")
    dataset = load_dataset(
        "oolongbench/oolong-real", DATASET_SUBSET, split=DATASET_SPLIT
    )
    example = dataset[EXAMPLE_INDEX]
    context = example["context_window_text"]
    question = example["question"]
    expected = str(example["answer"])
    print(f"Question: {question}")
    print(f"Expected answer: {expected}")
    print(f"Context length: {len(context):,} chars")

    # Create LLM function — used for both outer loop and inner llm_query calls
    chat_fn = create_chat_fn()

    # Run the RLM loop
    runner = LocalRLMRunner(
        chat_fn,
        system_prompt=RLM_SYSTEM_PROMPT_QWEN,
        max_iterations=MAX_ITERATIONS,
        max_depth=2,  # presumably caps recursive rlm_query depth; confirm in LocalRLMRunner
        verbose=True,
    )
    result = runner.run(context, question)

    # Results
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    print(f"Question: {question}")
    print(f"Expected: {expected}")
    print(f"Got: {result.final_answer}")
    print(f"Iterations: {result.iterations}")

    # Check against None rather than truthiness: a falsy but legitimate answer
    # (e.g. 0 when the expected answer is "0") must still be compared.
    answer = result.final_answer
    if answer is not None and _normalize_answer(answer) == _normalize_answer(expected):
        print("CORRECT!")
    else:
        print("INCORRECT")


if __name__ == "__main__":
    main()