File size: 8,601 Bytes
325052f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""Baseline inference script for the Scheduling Optimisation Environment.



Runs GPT-4o-mini (or falls back to deterministic mock responses) against all
three tasks and prints a structured score report.



Usage:

    OPENAI_API_KEY=sk-... python baseline.py

"""

from __future__ import annotations

import json
import os
import sys
from typing import Any

from environment import INSTANCE_BANK
from graders.grader_classification import ConflictGrader
from graders.grader_detection import FeasibilityGrader
from graders.grader_fix import RepairGrader
from models import Action


def _get_openai_client():
    """Build an OpenAI client from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``openai.OpenAI`` instance, or ``None`` when the variable is unset
        or empty, or when importing or constructing the client raises.
    """
    token = os.getenv("OPENAI_API_KEY", "")
    client = None
    if token:
        try:
            # Imported lazily so the script still runs without the package.
            from openai import OpenAI

            client = OpenAI(api_key=token)
        except Exception:
            # Best-effort: any failure here means the caller uses mock mode.
            client = None
    return client


def _llm_response(client, system_prompt: str, user_prompt: str) -> str:
    """Call GPT-4o-mini and return the response text with whitespace stripped.

    Args:
        client: Object exposing ``chat.completions.create`` (the value
            returned by ``_get_openai_client``).
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The first choice's message content, stripped. Returns ``""`` when the
        reply carries no text content, or when the call raises (the error is
        printed rather than propagated).
    """
    try:
        resp = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=1024,
            temperature=0.0,
        )
        content = resp.choices[0].message.content
    except Exception as e:
        # Best-effort: a failed call is reported and scored as an empty answer.
        print(f"  [LLM error: {e}]")
        return ""
    # ``content`` can be None; treat that as an empty answer instead of
    # letting ``.strip()`` raise and be misreported as an LLM error.
    return content.strip() if content else ""


# ---------------------------------------------------------------------------
# Mock fallback responses (used when no API key is available)
# ---------------------------------------------------------------------------

# Mock feasibility answers keyed by INSTANCE_BANK index: indices 0-9 are
# labelled "infeasible", indices 10 and 11 "feasible".
_MOCK_FEASIBILITY: dict[int, str] = {
    idx: "feasible" if idx >= 10 else "infeasible" for idx in range(12)
}

# Mock violation-type answers keyed by INSTANCE_BANK index (0-9). The five
# violation types appear in the same order for indices 0-4 and again for 5-9.
_MOCK_CLASSIFICATION: dict[int, str] = dict(
    enumerate(
        (
            "resource_overload",
            "deadline_violation",
            "precedence_violation",
            "availability_conflict",
            "capacity_exceeded",
        )
        * 2
    )
)


def _mock_repair(instance_idx: int) -> str:
    """Serialise the mock-mode repair answer for one bank entry as JSON.

    Uses the entry's ``optimal_schedule`` when it is present and non-empty;
    otherwise falls back to the instance's ``proposed_schedule`` (or an empty
    dict if that is missing too).
    """
    bank_entry = INSTANCE_BANK[instance_idx]
    schedule = bank_entry.get("optimal_schedule") or bank_entry["instance"].get(
        "proposed_schedule", {}
    )
    return json.dumps(schedule)


# ---------------------------------------------------------------------------
# Baseline runner
# ---------------------------------------------------------------------------


def _mean_score(scores: list[float]) -> float:
    """Return the arithmetic mean of *scores*, or 0.0 for an empty list."""
    return sum(scores) / len(scores) if scores else 0.0


def _record_task(results: dict[str, Any], task_id: str, scores: list[float]) -> float:
    """Store one task's summary under ``results["tasks"]``, print and return its average."""
    average = _mean_score(scores)
    results["tasks"][task_id] = {
        "average_score": round(average, 4),
        "num_instances": len(scores),
        "scores": scores,
    }
    print(f"  >> Average: {average:.3f}\n")
    return average


def _run_feasibility_task(client) -> list[float]:
    """Grade every INSTANCE_BANK entry on Task 1 and return the per-instance scores."""
    system_prompt = (
        "You are a scheduling expert. Determine if the proposed schedule "
        "satisfies all constraints. Reply with ONLY 'feasible' or 'infeasible'."
    )
    grader = FeasibilityGrader()
    scores: list[float] = []
    print("Task 1: Feasibility Check (easy)")
    for i, entry in enumerate(INSTANCE_BANK):
        instance_str = json.dumps(entry["instance"], indent=2)
        if client is not None:
            resp = _llm_response(client, system_prompt, instance_str)
        else:
            # Mock mode: indices missing from the table default to "infeasible".
            resp = _MOCK_FEASIBILITY.get(i, "infeasible")
        action = Action(response=resp, task_id="feasibility_check")
        score = grader.grade(action, entry)
        scores.append(score)
        status = "CORRECT" if score >= 0.95 else "wrong"
        expected = "feasible" if entry["is_feasible"] else "infeasible"
        print(
            f"  Instance {i:2d}: {status:7s} (score={score:.2f})  "
            f"expected={expected}  [{entry['description'][:45]}]"
        )
    return scores


def _run_classification_task(client) -> list[float]:
    """Grade every infeasible INSTANCE_BANK entry on Task 2 and return the scores."""
    system_prompt = (
        "You are a scheduling expert. Identify the constraint violation type. "
        "Reply with ONLY one of: resource_overload, deadline_violation, "
        "precedence_violation, availability_conflict, capacity_exceeded."
    )
    grader = ConflictGrader()
    scores: list[float] = []
    # Keep the original bank index so mock lookups and printed ids line up.
    infeasible_entries = [(i, e) for i, e in enumerate(INSTANCE_BANK) if not e["is_feasible"]]
    print("Task 2: Conflict Classification (medium)")
    for i, entry in infeasible_entries:
        instance_str = json.dumps(entry["instance"], indent=2)
        if client is not None:
            resp = _llm_response(client, system_prompt, instance_str)
        else:
            resp = _MOCK_CLASSIFICATION.get(i, "resource_overload")
        action = Action(response=resp, task_id="conflict_classification")
        score = grader.grade(action, entry)
        scores.append(score)
        status = "EXACT" if score >= 0.95 else ("partial" if score >= 0.45 else "wrong")
        print(
            f"  Instance {i:2d}: {status:7s} (score={score:.2f})  "
            f"expected={entry['violation_type']}"
        )
    return scores


def _run_repair_task(client) -> list[float]:
    """Grade Task 3 on infeasible entries that have an ``optimal_schedule``; return the scores."""
    system_prompt = (
        "You are a scheduling expert. Repair the infeasible schedule by "
        "returning a JSON object with key 'assignments': a list of "
        '{"job_id", "machine_id", "start_time"} dicts that satisfies all '
        "constraints and minimises makespan. Return ONLY valid JSON."
    )
    grader = RepairGrader()
    scores: list[float] = []
    repairable = [
        (i, e) for i, e in enumerate(INSTANCE_BANK)
        if not e["is_feasible"] and e.get("optimal_schedule")
    ]
    print("Task 3: Schedule Repair (hard)")
    for i, entry in repairable:
        instance_str = json.dumps(entry["instance"], indent=2)
        if client is not None:
            resp = _llm_response(client, system_prompt, instance_str)
        else:
            resp = _mock_repair(i)
        action = Action(response=resp, task_id="schedule_repair")
        score = grader.grade(action, entry)
        scores.append(score)
        print(
            f"  Instance {i:2d}: score={score:.2f}  "
            f"optimal_makespan={entry['optimal_makespan']}  "
            f"[{entry['description'][:45]}]"
        )
    return scores


def run_baseline() -> dict[str, Any]:
    """Execute the baseline across all three tasks and return scores.

    Uses the OpenAI client when ``_get_openai_client`` provides one; otherwise
    answers come from the mock tables and ``_mock_repair``. Progress and
    per-instance results are printed to stdout.

    Returns:
        A dict with ``"mode"`` (label of the answer source), ``"tasks"``
        (per-task ``average_score``, ``num_instances`` and ``scores`` keyed by
        ``feasibility_check``, ``conflict_classification`` and
        ``schedule_repair``) and ``"overall_average"`` (mean of the three task
        averages, rounded to 4 places).
    """
    client = _get_openai_client()
    use_llm = client is not None
    mode = "GPT-4o-mini" if use_llm else "mock (no API key — oracle responses)"
    print(f"\n{'='*65}")
    print(f"  SchedulingOptEnv — Baseline Evaluation ({mode})")
    print(f"{'='*65}\n")

    results: dict[str, Any] = {"mode": mode, "tasks": {}}

    # Each task runs to completion and is summarised before the next starts,
    # so the printed report keeps its task-by-task layout.
    avg_feas = _record_task(results, "feasibility_check", _run_feasibility_task(client))
    avg_conf = _record_task(
        results, "conflict_classification", _run_classification_task(client)
    )
    avg_repair = _record_task(results, "schedule_repair", _run_repair_task(client))

    # ----- Summary -----
    overall = (avg_feas + avg_conf + avg_repair) / 3
    results["overall_average"] = round(overall, 4)
    print(f"{'='*65}")
    print(f"  Overall Average Score: {overall:.3f}")
    print(f"{'='*65}\n")

    return results


if __name__ == "__main__":
    # Script entry point: run the evaluation; on any failure report it on
    # stderr and exit with status 1.
    try:
        run_baseline()
    except Exception as exc:
        sys.stderr.write(f"Baseline failed: {exc}\n")
        raise SystemExit(1)