File size: 9,292 Bytes
362bbff
 
 
 
dce8bf6
 
362bbff
 
dce8bf6
 
 
362bbff
 
 
 
 
 
 
 
da3eda7
362bbff
 
 
 
 
 
 
 
 
 
 
 
c64fcea
da3eda7
 
 
dce8bf6
 
362bbff
 
 
 
dce8bf6
 
 
 
 
 
 
362bbff
 
329e3d3
 
362bbff
 
329e3d3
 
dce8bf6
329e3d3
dce8bf6
 
 
 
362bbff
 
329e3d3
 
 
 
 
 
dce8bf6
 
 
329e3d3
 
dce8bf6
 
362bbff
 
 
329e3d3
 
 
 
 
362bbff
 
 
dce8bf6
362bbff
 
 
 
 
da3eda7
 
 
dce8bf6
da3eda7
362bbff
da3eda7
dce8bf6
362bbff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dce8bf6
 
362bbff
 
 
 
dce8bf6
362bbff
dce8bf6
362bbff
 
 
 
 
ca40c95
 
 
 
 
 
 
 
362bbff
 
 
 
 
 
 
 
dce8bf6
362bbff
 
dce8bf6
 
362bbff
 
 
 
 
 
dce8bf6
362bbff
 
dce8bf6
362bbff
 
 
 
 
dce8bf6
362bbff
 
 
dce8bf6
362bbff
 
 
dce8bf6
362bbff
 
 
 
 
 
 
 
 
 
 
 
dce8bf6
362bbff
 
 
 
dce8bf6
 
362bbff
 
dce8bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362bbff
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
#!/usr/bin/env python3
"""
Baseline inference script for the Shopify Store Audit environment.

Runs an LLM agent against all 3 tasks (easy, medium, hard) and
emits structured [START]/[STEP]/[END] logs for each.

Required env vars:
    API_BASE_URL  — LLM endpoint
    MODEL_NAME    — Model identifier
    HF_TOKEN      — API key
"""

from __future__ import annotations

import asyncio
import json
import os
import sys
from typing import Any, Dict, List, Optional

from openai import OpenAI

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from client import ShopifyStoreAuditEnv
from models import ShopifyStoreAuditAction

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# LLM credentials and endpoint. The `or` chains mean that an unset *or empty*
# variable falls through to the next candidate and finally to the default.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or ""
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://api.openai.com/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "gpt-4o"

# How to reach the environment: a Docker image name (either variable) or the
# URL of an already-running server. Both default to the empty string.
IMAGE_NAME = os.environ.get("IMAGE_NAME") or os.environ.get("LOCAL_IMAGE_NAME") or ""
ENV_URL = os.environ.get("ENV_URL", "")

# Benchmark label printed in [START] lines, and LLM sampling settings.
BENCHMARK = "shopify_store_audit"
TEMPERATURE = 0.2
MAX_TOKENS = 1024
# A task counts as a success when its final score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.5

# Tasks are run in this order; each has its own step budget.
TASKS = [
    {"id": task_id, "max_steps": step_budget}
    for task_id, step_budget in (
        ("product_listing_qa", 25),
        ("seo_collection_optimization", 35),
        ("full_store_audit", 50),
    )
]

# System message sent to the model on every request. It lists the commands the
# agent may issue and the JSON reply format that parse_action() expects.
# The separators are real em dashes (U+2014); they had previously been
# corrupted into a three-character mojibake sequence by a wrong-codepage decode.
SYSTEM_PROMPT = """\
You are a Shopify store audit agent. You interact with a real product catalog \
loaded from Shopify CSV exports. Discover issues and fix them through API commands.

AVAILABLE COMMANDS (send as JSON):
  Query commands (investigate):
    query_store_health  — diagnostic overview (START HERE, detail varies by difficulty)
    query_products      — list products (params: status, search, product_type, limit)
    query_product       — full product detail (params: product_id)
    query_collections   — list collections
    query_collection    — collection detail (params: collection_id)
    query_inventory     — inventory levels (params: product_id, location_id)
    query_orders        — list orders (params: fulfillment_status)

  Fix commands (mutations):
    update_product      — update fields (params: product_id, description, status, tags)
    update_variant      — update variant (params: product_id, price, compare_at_price, sku)
    update_product_seo  — set SEO (params: product_id, seo_title, seo_description)
    update_image_alt_text — set alt text (params: product_id, alt_text)
    add_product_image   — add image (params: product_id, url, alt_text)
    update_collection   — update rules (params: collection_id, rules)
    add_product_to_collection — (params: collection_id, product_id)
    adjust_inventory    — set qty (params: product_id, location_id, quantity)
    update_metafield    — set metafield (params: product_id, key, value)
    publish_product     — activate (params: product_id)
    update_order        — fulfill (params: order_id, fulfillment_status)

RESPONSE FORMAT — reply with ONLY a JSON object, no markdown:
{"command": "<command_name>", "params": {<parameters>}}

STRATEGY:
1. Start with query_store_health for a diagnostic overview
2. For issues you understand, apply fixes directly
3. For unclear issues, use query_product or query_inventory to investigate first
4. Don't repeat the same action — if it didn't work, try something different
5. Be efficient — discovery and fixing both earn rewards, repetition is penalised
"""

# ---------------------------------------------------------------------------
# Logging (mandatory format)
# ---------------------------------------------------------------------------

def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task run.

    The line has the form ``[START] task=<task> env=<env> model=<model>`` and
    is flushed immediately.
    """
    fields = ["[START]", f"task={task}", f"env={env}", f"model={model}"]
    print(" ".join(fields), flush=True)

def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record and flush it.

    ``reward`` is shown with two decimals, ``done`` as the lower-cased text of
    the value (``true``/``false`` for booleans), and a missing or empty
    ``error`` as the literal ``null``.
    """
    fields = [
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print(" ".join(fields), flush=True)

def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a task run and flush it.

    ``score`` is shown with three decimals; ``rewards`` becomes a
    comma-separated list of two-decimal values (empty when there are none).
    """
    outcome = str(success).lower()
    reward_csv = ",".join(map("{:.2f}".format, rewards))
    print(
        f"[END] success={outcome} steps={steps} score={score:.3f} rewards={reward_csv}",
        flush=True,
    )

# ---------------------------------------------------------------------------
# LLM interaction
# ---------------------------------------------------------------------------

def parse_action(text: str) -> Dict[str, Any]:
    """Extract the action object from a raw model reply.

    The reply is stripped, and if it opens with a Markdown code fence every
    fence line is dropped. Two candidates are then tried in order: the whole
    text, and the slice from the first ``{`` to the last ``}``. The first
    candidate that decodes to a JSON *object* is returned.

    Args:
        text: Raw completion text from the model.

    Returns:
        The decoded dict, or the default action
        ``{"command": "query_store_health", "params": {}}`` when no candidate
        yields a JSON object.
    """
    text = text.strip()
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.split("\n")
            if not line.strip().startswith("```")
        )

    candidates = [text]
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        candidates.append(text[start:end])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Valid JSON is not necessarily an object: a list, number or string
        # would break callers that use ``.get`` on the result.
        if isinstance(parsed, dict):
            return parsed

    return {"command": "query_store_health", "params": {}}


def get_model_action(
    client: OpenAI, step: int, observation_msg: str,
    observation_data: dict, last_reward: float, history: List[str],
) -> Dict[str, Any]:
    """Ask the chat model for the next action and return it parsed.

    The request consists of the system prompt, the last six history entries
    (each sent as a user message) and a final user message describing the
    current step. In that message ``observation_data`` is rendered as indented
    JSON and cut to its first 3000 characters.

    Returns:
        The result of ``parse_action`` on the reply text. If the request (or
        reading the reply) raises, the error is printed as a ``[DEBUG]`` line
        and the default ``query_store_health`` action is returned.
    """
    header_line = f"Step {step} | Last reward: {last_reward:+.3f}"
    observation_line = f"Observation: {observation_msg}"
    data_line = f"Data: {json.dumps(observation_data, indent=2)[:3000]}"
    context = "\n".join((header_line, observation_line, data_line))

    messages = (
        [{"role": "system", "content": SYSTEM_PROMPT}]
        + [{"role": "user", "content": entry} for entry in history[-6:]]
        + [{"role": "user", "content": context}]
    )

    try:
        request = {
            "model": MODEL_NAME,
            "messages": messages,
            "temperature": TEMPERATURE,
            "stream": False,
        }
        try:
            completion = client.chat.completions.create(**request, max_tokens=MAX_TOKENS)
        except Exception:
            # Any failure triggers one retry using the other token-limit
            # parameter name (presumably for endpoints that reject
            # `max_tokens` -- confirm against the target API).
            completion = client.chat.completions.create(
                **request, max_completion_tokens=MAX_TOKENS
            )
        reply = (completion.choices[0].message.content or "").strip()
        return parse_action(reply)
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return {"command": "query_store_health", "params": {}}


# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------

async def run_task(env, client: OpenAI, task_id: str, max_steps: int) -> float:
    """Run one task, emit [START]/[STEP]/[END], return score.

    The environment is reset for ``task_id``; then, for at most ``max_steps``
    steps, the model is asked for an action, the action is sent to the
    environment, and a ``[STEP]`` line is logged. The loop stops early once
    the environment reports ``done``.

    Args:
        env: Environment client providing async ``reset(task=...)`` and
            ``step(action)``.
        client: OpenAI client passed through to ``get_model_action``.
        task_id: Identifier of the task to reset the environment to.
        max_steps: Upper bound on the number of agent steps.

    Returns:
        The ``store_health_score`` of the last observation, clamped to
        ``[0.0, 1.0]``. If an exception interrupts the run, it is printed as a
        ``[DEBUG]`` line and the score stays ``0.0``. The ``[END]`` line is
        logged in every case.
    """
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset(task=task_id)
        obs = result.observation
        last_msg = obs.message
        last_data = obs.data
        last_reward = 0.0

        for step in range(1, max_steps + 1):
            if result.done:
                break

            action_dict = get_model_action(client, step, last_msg, last_data, last_reward, history)
            # The model's JSON may decode to something other than an object;
            # fall back to the default action instead of letting `.get` raise
            # and abort the whole episode.
            if not isinstance(action_dict, dict):
                action_dict = {"command": "query_store_health", "params": {}}
            command = action_dict.get("command", "query_store_health")
            params = action_dict.get("params", {})
            # `"params": null` (or any non-object) is treated as "no parameters".
            if not isinstance(params, dict):
                params = {}

            result = await env.step(ShopifyStoreAuditAction(command=command, params=params))
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done

            rewards.append(reward)
            steps_taken = step
            last_msg = obs.message
            last_data = obs.data
            last_reward = reward

            action_str = json.dumps(action_dict)
            log_step(step=step, action=action_str, reward=reward, done=done, error=None)
            history.append(f"Step {step}: {action_str} -> {obs.message} (reward={reward:+.3f})")

            if done:
                break

        # Final score comes from the most recent observation, clamped to [0, 1].
        score = obs.store_health_score if obs else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
    finally:
        # Always close the record, even for a failed or partial run.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

async def main() -> None:
    """Connect to the environment, run every task in ``TASKS``, then close it.

    If ``IMAGE_NAME`` is set, the environment is created from that Docker
    image. Otherwise a client is connected to ``ENV_URL``, or to
    ``http://localhost:8000`` when ``ENV_URL`` is empty. Errors raised while
    closing the environment are printed as a ``[DEBUG]`` line rather than
    propagated.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    if IMAGE_NAME:
        env = await ShopifyStoreAuditEnv.from_docker_image(IMAGE_NAME)
    else:
        env = ShopifyStoreAuditEnv(base_url=ENV_URL or "http://localhost:8000")
        await env.connect()

    try:
        for spec in TASKS:
            task_id, step_budget = spec["id"], spec["max_steps"]
            await run_task(env, client, task_id, step_budget)
    finally:
        try:
            await env.close()
        except Exception as close_err:
            print(f"[DEBUG] env.close() error: {close_err}", flush=True)


# Script entry point: drive the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())