| |
| """ |
| Simple REPL + Oolong example with recursive LLM calls (RLM paradigm). |
| |
| Uses LocalRLMRunner which handles both the outer loop (code generation) |
| and inner calls (llm_query/rlm_query) with a single chat function. |
| |
| Usage: |
| python examples/repl_oolong_simple.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
|
|
| from datasets import load_dataset |
| from huggingface_hub import InferenceClient |
| from repl_env import LocalRLMRunner |
| from repl_env.prompts import RLM_SYSTEM_PROMPT_QWEN |
|
|
| |
# --- Example configuration -------------------------------------------------
# Model identifier passed to the inference client and used as the per-call
# default inside the chat function.
MODEL_NAME = "Qwen/Qwen3-Coder-480B-A35B-Instruct"

# Which configuration / split / row of the "oolongbench/oolong-real" dataset
# the example loads.
DATASET_SUBSET = "toy_dnd"
DATASET_SPLIT = "validation"
EXAMPLE_INDEX = 0

# Forwarded to the runner as ``max_iterations``.
MAX_ITERATIONS = 30

# API token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
def create_chat_fn():
    """Create the chat function with Qwen3-Coder recommended params.

    A single ``InferenceClient`` is built once and captured by the returned
    closure, so every call made through the closure reuses the same client.

    Returns:
        A callable ``chat_fn(messages, model=None) -> str`` that sends
        ``messages`` as a chat completion request and returns the text of
        the first choice.
    """
    # HF_TOKEN is None when the environment variable is unset; it is passed
    # through to the client unchanged.
    client = InferenceClient(model=MODEL_NAME, token=HF_TOKEN, timeout=300)

    def chat_fn(messages: list[dict], model: str | None = None) -> str:
        """Send ``messages`` to the model and return the reply text.

        Args:
            messages: Chat messages forwarded as-is to the completion API.
            model: Optional model override; falls back to ``MODEL_NAME``
                when omitted or empty.

        Returns:
            The content of the first completion choice, or ``""`` when the
            response carries no text content.
        """
        response = client.chat.completions.create(
            model=model or MODEL_NAME,
            messages=messages,
            # Sampling settings (the "recommended params" named in the
            # enclosing docstring).
            max_tokens=1024,
            temperature=0.7,
            top_p=0.8,
            # Parameters outside the standard chat-completion arguments are
            # sent through ``extra_body``.
            extra_body={
                "top_k": 20,
                "repetition_penalty": 1.05,
            },
        )
        # ``message.content`` is optional in the response schema; normalise a
        # missing value to an empty string so the declared ``str`` return
        # type always holds for callers.
        content = response.choices[0].message.content
        return content if content is not None else ""

    return chat_fn
|
|
|
|
def main():
    """Run one Oolong example through ``LocalRLMRunner`` and print the outcome.

    Loads the configured dataset row, prints the question and expected
    answer, runs the runner on the row's context, and finally reports whether
    the runner's answer matches the expected one (case-insensitively, ignoring
    surrounding whitespace).
    """
    rule = "=" * 60

    # Banner
    print(rule)
    print("REPL + Oolong with Recursive LLM Calls (RLM)")
    print(rule)

    # Fetch the single dataset row this example works on.
    print(f"\nLoading dataset example {EXAMPLE_INDEX}...")
    rows = load_dataset(
        "oolongbench/oolong-real",
        DATASET_SUBSET,
        split=DATASET_SPLIT,
    )
    row = rows[EXAMPLE_INDEX]

    context = row["context_window_text"]
    question = row["question"]
    expected = str(row["answer"])

    print(f"Question: {question}")
    print(f"Expected answer: {expected}")
    print(f"Context length: {len(context):,} chars")

    # The same chat function is handed to the runner for all of its LLM calls.
    runner = LocalRLMRunner(
        create_chat_fn(),
        system_prompt=RLM_SYSTEM_PROMPT_QWEN,
        max_iterations=MAX_ITERATIONS,
        max_depth=2,
        verbose=True,
    )
    result = runner.run(context, question)

    # Summary
    print("\n" + rule)
    print("RESULTS")
    print(rule)
    print(f"Question: {question}")
    print(f"Expected: {expected}")
    print(f"Got: {result.final_answer}")
    print(f"Iterations: {result.iterations}")

    # A falsy answer (None, empty string) is never counted as correct.
    answer = result.final_answer
    is_match = bool(answer) and (
        str(answer).strip().lower() == expected.strip().lower()
    )
    print("CORRECT!" if is_match else "INCORRECT")
|
|
|
|
# Run the example only when this file is executed as a script, so importing
# the module does not trigger the dataset download or any model calls.
if __name__ == "__main__":
    main()
|
|