File size: 5,099 Bytes
fd1ecb5
 
63dd587
 
fd1ecb5
 
63dd587
fd1ecb5
63dd587
 
 
 
 
 
 
 
fd1ecb5
 
 
 
 
 
63dd587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd1ecb5
63dd587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd1ecb5
63dd587
 
 
 
fd1ecb5
63dd587
 
 
 
 
 
fd1ecb5
63dd587
 
fd1ecb5
63dd587
fd1ecb5
63dd587
 
 
 
 
 
fd1ecb5
63dd587
 
fd1ecb5
63dd587
 
 
 
fd1ecb5
63dd587
 
fd1ecb5
63dd587
 
fd1ecb5
63dd587
fd1ecb5
63dd587
 
 
 
 
 
 
 
 
 
 
fd1ecb5
 
63dd587
 
 
 
 
 
fd1ecb5
63dd587
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
"""
Inference script for StructuralDesignEnv.
LLM agent (Claude) designs a steel building frame step by step.

Usage:
    python scripts/inference.py [task_id]

    task_id: task1_warehouse (default) | task2_office | task3_hospital

Environment variables:
    ENV_URL          — base URL of the running server (default: http://localhost:7860)
    INFERENCE_MODEL  — model name (default: claude-opus-4-6)
    ANTHROPIC_API_KEY or OPENAI_API_KEY
    OPENAI_BASE_URL  — override API base URL
"""

import json
import os
import sys

import httpx
from openai import OpenAI

BASE_URL = os.getenv("ENV_URL", "http://localhost:7860")
MODEL = os.getenv("INFERENCE_MODEL", "claude-opus-4-6")

SYSTEM_PROMPT = """You are a structural engineer designing a building frame step-by-step.
You place columns, beams, and shear walls on a building grid, then receive
physics analysis showing whether your design is structurally safe.

PHYSICS RULES:
- Beams carry vertical load via bending: M = w*L^2/8. Longer spans need bigger sections.
- Columns carry vertical load via compression. More floors = higher axial load.
- Lateral loads (wind/seismic) require lateral resistance: shear walls or moment frames.
- Utilization ratio (UR) = demand/capacity. Must be < 1.0 for all members.
- UR=1.47 means 47% overstressed → upgrade section or reduce span.
- Deflection limit: maximum beam deflection < span/300.
- Lateral drift limit: story drift < height/500.

DESIGN STRATEGY:
1. Establish column grid (spacing 4-6m gives economical spans)
2. Add beams in both directions
3. Check physics → upgrade any UR > 1.0 members
4. Add shear walls if lateral drift > limit
5. Downgrade members with UR < 0.6 (wasteful)
6. Signal "done" only when all URs < 1.0

Respond with a single JSON action object matching the StructuralAction schema.
Do not include any text outside the JSON object."""

client = OpenAI(
    base_url=os.getenv("OPENAI_BASE_URL", "https://api.anthropic.com/v1"),
    api_key=os.getenv("ANTHROPIC_API_KEY", os.getenv("OPENAI_API_KEY", "")),
)


def run_episode(task_id: str = "task1_warehouse"):
    env = httpx.Client(base_url=BASE_URL, timeout=60)

    # Reset
    resp = env.post("/reset", json={"task_id": task_id})
    resp.raise_for_status()
    data = resp.json()
    session_id = data["session_id"]
    obs = data["observation"]

    print(f"\n{'=' * 60}")
    print(f"Task: {task_id}  |  Session: {session_id}")
    print(f"{'=' * 60}")
    print(obs["message"])

    messages = [{"role": "user", "content": obs["message"]}]
    done = False
    total_reward = 0.0
    step = 0
    max_steps = obs.get("max_steps", 100)

    while not done and step < max_steps + 5:
        # Query LLM
        try:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[{"role": "system", "content": SYSTEM_PROMPT}] + messages,
                max_tokens=512,
                temperature=0.0,
            )
            action_str = response.choices[0].message.content.strip()
        except Exception as e:
            print(f"\n[LLM error] {e}")
            break

        # Strip markdown code fences if present
        if action_str.startswith("```"):
            action_str = action_str.split("```")[1]
            if action_str.startswith("json"):
                action_str = action_str[4:]
            action_str = action_str.strip()

        print(f"\n[Step {step + 1}] Agent: {action_str}")
        messages.append({"role": "assistant", "content": action_str})

        # Step environment
        try:
            resp = env.post(
                "/step",
                json={"session_id": session_id, "message": action_str},
            )
            resp.raise_for_status()
            step_data = resp.json()
        except Exception as e:
            print(f"\n[HTTP error] {e}")
            break

        obs = step_data["observation"]
        reward = step_data["reward"]
        done = step_data["done"]
        info = step_data.get("info", {})

        total_reward += reward
        step += 1

        print(f"Reward: {reward:+.4f}  |  Total: {total_reward:+.4f}  |  Done: {done}")
        print(obs["message"])

        messages.append({"role": "user", "content": obs["message"]})

        if done:
            graded = info.get("graded_score", 0.0)
            print(f"\n{'=' * 60}")
            print(f"EPISODE COMPLETE")
            print(f"Steps: {step}  |  Total reward: {total_reward:.3f}  |  Score: {graded:.4f}")
            print(f"Valid: {obs.get('is_structurally_valid', False)}")
            print(f"Elements: {obs.get('n_elements_placed', 0)}")
            print(f"Steel mass: {obs.get('total_steel_mass_kg', 0):.0f} kg")
            print(f"{'=' * 60}\n")

    return total_reward


if __name__ == "__main__":
    task = sys.argv[1] if len(sys.argv) > 1 else "task1_warehouse"
    valid_tasks = {"task1_warehouse", "task2_office", "task3_hospital"}
    if task not in valid_tasks:
        print(f"Unknown task '{task}'. Valid: {sorted(valid_tasks)}")
        sys.exit(1)

    run_episode(task)