shopify-store-audit / inference.py
aatmk-panse
feat: environment redesign — real CSV data, shaped rewards, difficulty tiers
329e3d3
#!/usr/bin/env python3
"""
Baseline inference script for the Shopify Store Audit environment.
Runs an LLM agent against all 3 tasks (easy, medium, hard) and
emits structured [START]/[STEP]/[END] logs for each.
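Environment variables:
API_BASE_URL — LLM endpoint (defaults to https://api.openai.com/v1)
MODEL_NAME — Model identifier (defaults to gpt-4o)
HF_TOKEN — API key (falls back to API_KEY)
IMAGE_NAME — optional Docker image to launch the env from (falls back to LOCAL_IMAGE_NAME)
ENV_URL — optional URL of a running env server, used when no image is set (defaults to http://localhost:8000)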
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
from typing import Any, Dict, List, Optional
from openai import OpenAI
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from client import ShopifyStoreAuditEnv
from models import ShopifyStoreAuditAction
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
def _env_first(*names: str, default: str = "") -> str:
    """Return the first set, non-empty environment variable among *names*, else *default*."""
    for name in names:
        value = os.getenv(name)
        if value:
            return value
    return default


# LLM endpoint and credentials (HF_TOKEN preferred, API_KEY as a fallback).
API_KEY = _env_first("HF_TOKEN", "API_KEY")
API_BASE_URL = _env_first("API_BASE_URL", default="https://api.openai.com/v1")
MODEL_NAME = _env_first("MODEL_NAME", default="gpt-4o")
# How to reach the environment: a Docker image to launch, or a server URL.
IMAGE_NAME = _env_first("IMAGE_NAME", "LOCAL_IMAGE_NAME")
ENV_URL = _env_first("ENV_URL")
BENCHMARK = "shopify_store_audit"
# Sampling settings passed to every chat-completion request.
TEMPERATURE = 0.2
MAX_TOKENS = 1024
# A task counts as successful once its clamped score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.5
# Tasks are run in this order, each with its own step budget.
TASKS = [
    {"id": task_id, "max_steps": budget}
    for task_id, budget in (
        ("product_listing_qa", 25),
        ("seo_collection_optimization", 35),
        ("full_store_audit", 50),
    )
]
# System prompt sent as the first chat message on every model call. It lists
# the commands (with their params) the agent may emit and pins the reply format
# to a single JSON object, which parse_action() then decodes.
SYSTEM_PROMPT = """\
You are a Shopify store audit agent. You interact with a real product catalog \
loaded from Shopify CSV exports. Discover issues and fix them through API commands.
AVAILABLE COMMANDS (send as JSON):
Query commands (investigate):
query_store_health — diagnostic overview (START HERE, detail varies by difficulty)
query_products — list products (params: status, search, product_type, limit)
query_product — full product detail (params: product_id)
query_collections — list collections
query_collection — collection detail (params: collection_id)
query_inventory — inventory levels (params: product_id, location_id)
query_orders — list orders (params: fulfillment_status)
Fix commands (mutations):
update_product — update fields (params: product_id, description, status, tags)
update_variant — update variant (params: product_id, price, compare_at_price, sku)
update_product_seo — set SEO (params: product_id, seo_title, seo_description)
update_image_alt_text — set alt text (params: product_id, alt_text)
add_product_image — add image (params: product_id, url, alt_text)
update_collection — update rules (params: collection_id, rules)
add_product_to_collection — (params: collection_id, product_id)
adjust_inventory — set qty (params: product_id, location_id, quantity)
update_metafield — set metafield (params: product_id, key, value)
publish_product — activate (params: product_id)
update_order — fulfill (params: order_id, fulfillment_status)
RESPONSE FORMAT — reply with ONLY a JSON object, no markdown:
{"command": "<command_name>", "params": {<parameters>}}
STRATEGY:
1. Start with query_store_health for a diagnostic overview
2. For issues you understand, apply fixes directly
3. For unclear issues, use query_product or query_inventory to investigate first
4. Don't repeat the same action — if it didn't work, try something different
5. Be efficient — discovery and fixing both earn rewards, repetition is penalised
"""
# ---------------------------------------------------------------------------
# Logging (mandatory format)
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Print the mandatory ``[START]`` line that opens a task's log."""
    fields = {"task": task, "env": env, "model": model}
    line = "[START] " + " ".join(f"{key}={value}" for key, value in fields.items())
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one mandatory ``[STEP]`` line.

    The reward is shown with two decimals, ``done`` is lower-cased, and a
    falsy ``error`` is rendered as ``null``.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    )
    print("[STEP]", *fields, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the mandatory ``[END]`` line summarising a finished task.

    The score uses three decimals; the per-step rewards use two decimals and
    are comma-joined.
    """
    joined_rewards = ",".join(map("{:.2f}".format, rewards))
    line = "[END] success={} steps={} score={:.3f} rewards={}".format(
        str(success).lower(), steps, score, joined_rewards,
    )
    print(line, flush=True)
# ---------------------------------------------------------------------------
# LLM interaction
# ---------------------------------------------------------------------------
def parse_action(text: str) -> Dict[str, Any]:
    """Parse an LLM reply into an action dict like ``{"command": ..., "params": {...}}``.

    Accepts any of the following:
    - a bare JSON object;
    - a JSON object wrapped in a Markdown code fence, either on its own
      lines or on the fence line itself;
    - a JSON object embedded in surrounding prose.

    Always returns a dict. If no JSON *object* can be recovered, the safe
    default ``{"command": "query_store_health", "params": {}}`` is returned.
    This includes replies that are valid JSON of another type, such as a list
    or a string. A ``"params"`` value that is present but not an object is
    replaced with ``{}``, so callers can forward it to the environment as-is.
    """
    raw = text.strip()
    cleaned = raw
    if raw.startswith("```"):
        # Drop fence lines such as ``` and ```json, keeping the body between them.
        cleaned = "\n".join(
            line for line in raw.split("\n") if not line.strip().startswith("```")
        )

    candidates = [cleaned]
    # Fallback: the outermost {...} span of the unmodified reply. Slicing the raw
    # text (not ``cleaned``) also recovers JSON written on the fence line itself.
    start = raw.find("{")
    end = raw.rfind("}") + 1
    if start >= 0 and end > start:
        candidates.append(raw[start:end])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (list, string, number) is unusable as an action.
        if isinstance(parsed, dict):
            if not isinstance(parsed.get("params", {}), dict):
                parsed["params"] = {}
            return parsed
    return {"command": "query_store_health", "params": {}}
def get_model_action(
    client: OpenAI, step: int, observation_msg: str,
    observation_data: dict, last_reward: float, history: List[str],
) -> Dict[str, Any]:
    """Ask the LLM for the next action and return it as a command dict.

    The chat transcript has three parts, in order:
    - SYSTEM_PROMPT;
    - the last six ``history`` entries, each as a user message;
    - a summary of the current observation, with the JSON-dumped
      ``observation_data`` truncated to 3000 characters.

    Never raises. Any failure while building the prompt, calling the API or
    reading the reply is printed as ``[DEBUG]``, and the safe default
    ``{"command": "query_store_health", "params": {}}`` is returned.
    """
    try:
        # Building the context sits inside the try so a bad observation cannot
        # escape this function. default=str renders values json cannot encode
        # natively instead of raising; this text is only shown to the model.
        data_dump = json.dumps(observation_data, indent=2, default=str)[:3000]
        context = (
            f"Step {step} | Last reward: {last_reward:+.3f}\n"
            f"Observation: {observation_msg}\n"
            f"Data: {data_dump}"
        )
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        messages.extend({"role": "user", "content": h} for h in history[-6:])
        messages.append({"role": "user", "content": context})

        kwargs = dict(model=MODEL_NAME, messages=messages, temperature=TEMPERATURE, stream=False)
        try:
            kwargs["max_tokens"] = MAX_TOKENS
            completion = client.chat.completions.create(**kwargs)
        except Exception:
            # Presumably for models/endpoints that reject ``max_tokens`` in favour
            # of ``max_completion_tokens``: retry once with that name.
            # NOTE: this retry fires on any error from the first call.
            kwargs.pop("max_tokens", None)
            kwargs["max_completion_tokens"] = MAX_TOKENS
            completion = client.chat.completions.create(**kwargs)
        text = (completion.choices[0].message.content or "").strip()
        return parse_action(text)
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return {"command": "query_store_health", "params": {}}
# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------
async def run_task(env, client: OpenAI, task_id: str, max_steps: int) -> float:
    """Run one task, emit [START]/[STEP]/[END], return score.

    Resets ``env`` to ``task_id``, then asks the LLM for one action per step.
    This runs for at most ``max_steps`` steps and stops early once the
    environment reports ``done``.

    The score is the final observation's ``store_health_score``, clamped to
    [0, 1]. The task counts as a success when the score reaches
    ``SUCCESS_SCORE_THRESHOLD``.

    Exceptions raised during the episode are printed as ``[DEBUG]`` and
    swallowed, so the remaining tasks still run. In that case the score stays
    at 0.0. The ``[END]`` line is always emitted.
    """
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        loop = asyncio.get_running_loop()
        result = await env.reset(task=task_id)
        obs = result.observation
        last_msg = obs.message
        last_data = obs.data
        last_reward = 0.0
        for step in range(1, max_steps + 1):
            if result.done:
                break
            # get_model_action makes a blocking (synchronous) OpenAI call. Run it
            # in the default thread pool so the event loop is not frozen meanwhile.
            action_dict = await loop.run_in_executor(
                None, get_model_action,
                client, step, last_msg, last_data, last_reward, history,
            )
            # Guard against malformed model output so one bad reply cannot abort
            # the whole task: fall back to the safe default command / empty params.
            if not isinstance(action_dict, dict):
                action_dict = {"command": "query_store_health", "params": {}}
            command = action_dict.get("command", "query_store_health")
            if not isinstance(command, str) or not command:
                command = "query_store_health"
            params = action_dict.get("params", {})
            if not isinstance(params, dict):
                params = {}
            result = await env.step(ShopifyStoreAuditAction(command=command, params=params))
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            steps_taken = step
            last_msg = obs.message
            last_data = obs.data
            last_reward = reward
            action_str = json.dumps(action_dict)
            log_step(step=step, action=action_str, reward=reward, done=done, error=None)
            history.append(f"Step {step}: {action_str} -> {obs.message} (reward={reward:+.3f})")
            if done:
                break
        # Never leave ``score`` as a non-float: log_end formats it with ``:.3f``
        # in the finally clause, and a failure there would escape this function.
        raw_score = getattr(obs, "store_health_score", None)
        score = min(max(float(raw_score), 0.0), 1.0) if raw_score is not None else 0.0
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as e:
        print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def main() -> None:
    """Create the LLM client and the environment, run every task in TASKS, then close the env.

    The environment is launched from IMAGE_NAME when that is set. Otherwise the
    script connects to ENV_URL, falling back to a local server on port 8000.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    if IMAGE_NAME:
        env = await ShopifyStoreAuditEnv.from_docker_image(IMAGE_NAME)
    else:
        env = ShopifyStoreAuditEnv(base_url=ENV_URL or "http://localhost:8000")
        await env.connect()
    try:
        for cfg in TASKS:
            await run_task(env, client, cfg["id"], cfg["max_steps"])
    finally:
        # Closing is best-effort: report the failure but do not mask earlier errors.
        try:
            await env.close()
        except Exception as exc:
            print(f"[DEBUG] env.close() error: {exc}", flush=True)


if __name__ == "__main__":
    asyncio.run(main())