File size: 6,027 Bytes
519736d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52eb44f
 
 
de442f8
52eb44f
519736d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205dc3f
 
519736d
 
 
 
 
 
 
70022c4
519736d
 
205dc3f
519736d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205dc3f
519736d
 
 
 
 
70022c4
519736d
 
 
 
 
 
 
 
 
70022c4
519736d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Inference Script β€” SocraticEnv
================================
MANDATORY variables (set in environment before running):
  API_BASE_URL  β€” The API endpoint for the LLM
  MODEL_NAME    β€” The model identifier to use
  HF_TOKEN      β€” Your HuggingFace token (used as API key)

Run:
  python inference.py
"""

import os
import time
import requests
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

# ── Config ────────────────────────────────────────────────
# All values are read from the environment (a local .env file was loaded
# above via load_dotenv); the second argument to os.getenv is the fallback.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/novita/v3/openai")
MODEL_NAME   = os.getenv("MODEL_NAME",   "meta-llama/llama-3.1-8b-instruct")
HF_TOKEN     = os.getenv("HF_TOKEN")  # used as the API key; no default — must be set
ENV_URL      = os.getenv("ENV_URL", "http://localhost:7860")  # SocraticEnv server


MAX_TURNS    = 10   # hard cap on dialogue turns per episode
TEMPERATURE  = 0.3  # fairly deterministic sampling for the student replies

# OpenAI-compatible client pointed at the configured router endpoint.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN,
)

# Task ids exposed by the environment; main() runs each of these once.
TASKS = ["factual_recall", "socratic_dialogue", "misconception_trap"]

# System prompt framing the LLM as the *student* side of the dialogue.
SYSTEM_PROMPT = """You are an intelligent student in a Socratic dialogue with a tutor.
Your goals:
1. Answer questions clearly and accurately using correct terminology.
2. Show your reasoning β€” explain WHY, not just WHAT.
3. Be alert: if the tutor states something FALSE or misleading, 
   you must confidently disagree and explain the correct answer.
4. Stay engaged and thoughtful throughout the conversation.
Keep responses focused and between 3-6 sentences."""


def call_llm(messages: list) -> str:
    """Send *messages* to the configured model and return the reply text.

    Args:
        messages: Conversation history in OpenAI chat format
            (list of ``{"role": ..., "content": ...}`` dicts).

    Returns:
        The model's reply stripped of surrounding whitespace. If the API
        call fails or returns no content, a neutral fallback sentence is
        returned so the calling episode can continue instead of crashing.
    """
    fallback = "I need to think about that more carefully before responding."
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=300,
            temperature=TEMPERATURE,
        )
        # message.content can legitimately be None (e.g. refusals / tool
        # calls); handle it explicitly rather than relying on .strip()
        # raising AttributeError to reach the fallback path.
        content = completion.choices[0].message.content
        if content is None:
            print("  [LLM ERROR] empty completion content")
            return fallback
        return content.strip()
    except Exception as e:
        print(f"  [LLM ERROR] {e}")
        return fallback


def reset_env(task_id: str) -> dict:
    """Start a new episode for *task_id* and return the env's JSON payload.

    The response contains at least ``session_id`` and ``observation``.
    Raises ``requests.HTTPError`` on a non-2xx status.
    """
    # Explicit timeout: requests has no default and would otherwise hang
    # forever on an unresponsive environment server.
    r = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
    r.raise_for_status()
    return r.json()


def step_env(response: str, session_id: str) -> dict:
    """Submit the agent's *response* for *session_id* and return the step result.

    The returned dict carries ``reward``, ``done`` and (unless finished)
    the next ``observation``. Raises ``requests.HTTPError`` on non-2xx.
    """
    # Explicit timeout for the same reason as reset_env: requests has no
    # default timeout and would block indefinitely on a hung server.
    r = requests.post(
        f"{ENV_URL}/step",
        json={"response": response, "session_id": session_id},
        timeout=30,
    )
    r.raise_for_status()
    return r.json()


def run_task(task_id: str) -> dict:
    """Play one full episode of *task_id* against the environment.

    Drives the tutor/student loop for up to MAX_TURNS exchanges, accumulating
    the per-turn reward, and returns a summary dict with the task id, the
    normalized final score (mean reward capped at 1.0), the number of turns
    taken, and whether the 0.5 pass threshold was reached.
    """
    print(f"\n── Task: {task_id} ─────────────────────────────────")
    print(f"[START] task={task_id}", flush=True)

    init = reset_env(task_id)
    session_id = init["session_id"]
    observation = init["observation"]

    history = [{"role": "system", "content": SYSTEM_PROMPT}]
    score_sum = 0.0
    turn_count = 0

    print(f"  Tutor: {observation['question'][:100]}...")

    for turn in range(1, MAX_TURNS + 1):
        # Tutor's question goes into the chat history first.
        history.append({"role": "user", "content": observation["question"]})

        # Student reply from the LLM.
        reply = call_llm(history)
        history.append({"role": "assistant", "content": reply})

        print(f"  Agent (turn {turn}): {reply[:80]}...")

        # Score this exchange in the environment.
        outcome = step_env(reply, session_id)
        step_reward = outcome["reward"]["score"]
        score_sum += step_reward
        turn_count = turn

        print(f"  Reward: {step_reward:.3f} | Breakdown: {outcome['reward']['breakdown']}")
        print(f"[STEP] step={turn_count} reward={step_reward}", flush=True)

        if outcome["done"]:
            break

        observation = outcome["observation"]
        time.sleep(0.5)  # throttle so we don't hammer the API

    # Mean reward over the turns actually played, capped at 1.0; the
    # max(..., 1) guards against a zero-turn episode.
    final_score = round(min(score_sum / max(turn_count, 1), 1.0), 3)
    passed = final_score >= 0.5
    print(f"  ── Final Score: {final_score} ({'PASS' if passed else 'FAIL'})")
    print(f"[END] task={task_id} score={final_score} steps={turn_count}", flush=True)

    return {
        "task": task_id,
        "score": final_score,
        "turns": turn_count,
        "passed": passed,
    }


def main():
    """Entry point: verify the env server is reachable, run every task in
    TASKS once, and print a per-task plus overall results summary."""
    print("\n════════════════════════════════════════════")
    print("  SocraticEnv β€” Baseline Inference Script")
    print("════════════════════════════════════════════")
    print(f"  Model:   {MODEL_NAME}")
    print(f"  Env URL: {ENV_URL}")
    print("════════════════════════════════════════════")

    # Fail fast if the environment server is down; the explicit timeout
    # keeps this health check from hanging on an unresponsive host
    # (requests has no default timeout).
    try:
        r = requests.get(f"{ENV_URL}/ping", timeout=10)
        r.raise_for_status()
        print("  Env: ONLINE βœ“")
    except Exception:
        print("  ERROR: Environment is not running!")
        print("  Start it first with: python main.py")
        return

    results = {}
    for task_id in TASKS:
        results[task_id] = run_task(task_id)
        time.sleep(1)  # brief pause between episodes

    # Summary
    print("\n════════════════════════════════════════════")
    print("  RESULTS SUMMARY")
    print("════════════════════════════════════════════")
    all_scores = []
    for task_id, r in results.items():
        status = "βœ“ PASS" if r["passed"] else "βœ— FAIL"
        print(f"  {status} | {task_id:<25} | Score: {r['score']:.3f}")
        all_scores.append(r["score"])

    # Guard the mean against a zero division if TASKS is ever emptied.
    overall = round(sum(all_scores) / len(all_scores), 3) if all_scores else 0.0
    print(f"\n  Overall Score: {overall:.3f}")
    print(f"  All Passed:   {all(r['passed'] for r in results.values())}")
    print("════════════════════════════════════════════\n")


if __name__ == "__main__":
    main()