#!/usr/bin/env python3
"""
Inference script for StructuralDesignEnv.
LLM agent (Claude) designs a steel building frame step by step.
Usage:
python scripts/inference.py [task_id]
task_id: task1_warehouse (default) | task2_office | task3_hospital
Environment variables:
ENV_URL — base URL of the running server (default: http://localhost:7860)
INFERENCE_MODEL — model name (default: claude-opus-4-6)
ANTHROPIC_API_KEY or OPENAI_API_KEY
OPENAI_BASE_URL — override API base URL
"""
import json
import os
import sys
import httpx
from openai import OpenAI
# Base URL of the running environment server (override with ENV_URL).
BASE_URL = os.getenv("ENV_URL", "http://localhost:7860")
# Chat model name passed to the completions API (override with INFERENCE_MODEL).
MODEL = os.getenv("INFERENCE_MODEL", "claude-opus-4-6")
# System prompt sent on every LLM call: explains the physics rules and the
# expected response format (a single JSON action object, no surrounding text).
SYSTEM_PROMPT = """You are a structural engineer designing a building frame step-by-step.
You place columns, beams, and shear walls on a building grid, then receive
physics analysis showing whether your design is structurally safe.
PHYSICS RULES:
- Beams carry vertical load via bending: M = w*L^2/8. Longer spans need bigger sections.
- Columns carry vertical load via compression. More floors = higher axial load.
- Lateral loads (wind/seismic) require lateral resistance: shear walls or moment frames.
- Utilization ratio (UR) = demand/capacity. Must be < 1.0 for all members.
- UR=1.47 means 47% overstressed → upgrade section or reduce span.
- Deflection limit: maximum beam deflection < span/300.
- Lateral drift limit: story drift < height/500.
DESIGN STRATEGY:
1. Establish column grid (spacing 4-6m gives economical spans)
2. Add beams in both directions
3. Check physics → upgrade any UR > 1.0 members
4. Add shear walls if lateral drift > limit
5. Downgrade members with UR < 0.6 (wasteful)
6. Signal "done" only when all URs < 1.0
Respond with a single JSON action object matching the StructuralAction schema.
Do not include any text outside the JSON object."""
# OpenAI-compatible client. Defaults to Anthropic's endpoint; set
# OPENAI_BASE_URL to point at a different provider. The Anthropic key takes
# precedence over OPENAI_API_KEY when both are set.
client = OpenAI(
    base_url=os.getenv("OPENAI_BASE_URL", "https://api.anthropic.com/v1"),
    api_key=os.getenv("ANTHROPIC_API_KEY", os.getenv("OPENAI_API_KEY", "")),
)
def run_episode(task_id: str = "task1_warehouse") -> float:
    """Run one LLM-driven design episode against the environment server.

    Resets the environment, then loops: query the LLM for a JSON action,
    post it to ``/step``, and feed the resulting observation back into the
    conversation until the environment signals ``done`` or the step budget
    is exhausted.

    Args:
        task_id: Environment task identifier (e.g. "task1_warehouse").

    Returns:
        Cumulative reward over the episode (0.0 if it ended before any step).
    """
    # Context manager guarantees the connection pool is closed even when the
    # LLM call or a /step request raises or we break out early. (Previously
    # the client was never closed — a resource leak.)
    with httpx.Client(base_url=BASE_URL, timeout=60) as env:
        # Reset the environment; grab the session handle + first observation.
        resp = env.post("/reset", json={"task_id": task_id})
        resp.raise_for_status()
        data = resp.json()
        session_id = data["session_id"]
        obs = data["observation"]

        print(f"\n{'=' * 60}")
        print(f"Task: {task_id} | Session: {session_id}")
        print(f"{'=' * 60}")
        print(obs["message"])

        messages = [{"role": "user", "content": obs["message"]}]
        done = False
        total_reward = 0.0
        step = 0
        max_steps = obs.get("max_steps", 100)
        info = {}  # last step's info dict; used for the final summary

        # Small buffer past max_steps so the env itself can signal termination.
        while not done and step < max_steps + 5:
            # Query the LLM for the next JSON action.
            try:
                response = client.chat.completions.create(
                    model=MODEL,
                    messages=[{"role": "system", "content": SYSTEM_PROMPT}] + messages,
                    max_tokens=512,
                    temperature=0.0,
                )
                action_str = response.choices[0].message.content.strip()
            except Exception as e:
                print(f"\n[LLM error] {e}")
                break

            # Strip markdown code fences if present (models often wrap JSON).
            if action_str.startswith("```"):
                action_str = action_str.split("```")[1]
                if action_str.startswith("json"):
                    action_str = action_str[4:]
                action_str = action_str.strip()

            print(f"\n[Step {step + 1}] Agent: {action_str}")
            messages.append({"role": "assistant", "content": action_str})

            # Step the environment with the raw action text; the server is
            # responsible for parsing/validating the JSON.
            try:
                resp = env.post(
                    "/step",
                    json={"session_id": session_id, "message": action_str},
                )
                resp.raise_for_status()
                step_data = resp.json()
            except Exception as e:
                print(f"\n[HTTP error] {e}")
                break

            obs = step_data["observation"]
            reward = step_data["reward"]
            done = step_data["done"]
            info = step_data.get("info", {})
            total_reward += reward
            step += 1

            print(f"Reward: {reward:+.4f} | Total: {total_reward:+.4f} | Done: {done}")
            print(obs["message"])
            messages.append({"role": "user", "content": obs["message"]})

        if done:
            # Final summary — printed only when the env signalled completion,
            # not when we bailed out on an LLM/HTTP error or step budget.
            graded = info.get("graded_score", 0.0)
            print(f"\n{'=' * 60}")
            print("EPISODE COMPLETE")
            print(f"Steps: {step} | Total reward: {total_reward:.3f} | Score: {graded:.4f}")
            print(f"Valid: {obs.get('is_structurally_valid', False)}")
            print(f"Elements: {obs.get('n_elements_placed', 0)}")
            print(f"Steel mass: {obs.get('total_steel_mass_kg', 0):.0f} kg")
            print(f"{'=' * 60}\n")

        return total_reward
if __name__ == "__main__":
    # Pick the task from the CLI, defaulting to the warehouse scenario.
    chosen = sys.argv[1] if len(sys.argv) > 1 else "task1_warehouse"
    known_tasks = ("task1_warehouse", "task2_office", "task3_hospital")
    if chosen not in known_tasks:
        print(f"Unknown task '{chosen}'. Valid: {sorted(known_tasks)}")
        sys.exit(1)
    run_episode(chosen)