junaid0600 commited on
Commit
82d2df4
Β·
1 Parent(s): 7810e7a

Fix [END] format - remove score field, add HF_TOKEN validation

Browse files
Files changed (1) hide show
  1. inference.py +6 -20
inference.py CHANGED
@@ -14,40 +14,26 @@ from dotenv import load_dotenv
14
  load_dotenv()
15
 
16
  from env.environment import SQLDebuggerEnvironment
17
- from env.models import Action, ActionType, DifficultyLevel
18
-
19
  # ─────────────────────────────────────────────
20
  # ENVIRONMENT VARIABLES
21
  # ─────────────────────────────────────────────
22
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or "dummy-key"
23
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
24
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
25
- BENCHMARK = "sql-query-debugger"
26
- MAX_STEPS = 10
27
- SUCCESS_SCORE_THRESHOLD = 0.5
28
 
29
- # ─────────────────────────────────────────────
30
- # LOGGING FUNCTIONS β€” exact format required
 
 
31
  # ─────────────────────────────────────────────
32
 
33
  def log_start(task: str, env: str, model: str) -> None:
34
  print(f"[START] task={task} env={env} model={model}", flush=True)
35
 
36
 
37
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
38
- error_val = error if error else "null"
39
- done_val = str(done).lower()
40
- print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
41
-
42
-
43
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
44
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
45
- print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
46
-
47
-
48
- # ─────────────────────────────────────────────
49
- # SYSTEM PROMPT
50
- # ─────────────────────────────────────────────
51
 
52
  SYSTEM_PROMPT = textwrap.dedent("""
53
  You are an expert SQL debugger. You will be given a buggy SQL query and must fix it.
 
14
  load_dotenv()
15
 
16
  from env.environment import SQLDebuggerEnvironment
 
 
17
  # ─────────────────────────────────────────────
18
  # ENVIRONMENT VARIABLES
19
  # ─────────────────────────────────────────────
 
20
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
21
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
22
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
 
23
 
24
+ if HF_TOKEN is None:
25
+ raise ValueError("HF_TOKEN environment variable is required")
26
+
27
+ API_KEY = HF_TOKEN
28
  # ─────────────────────────────────────────────
29
 
30
  def log_start(task: str, env: str, model: str) -> None:
31
  print(f"[START] task={task} env={env} model={model}", flush=True)
32
 
33
 
 
 
 
 
 
 
34
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
35
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
36
+ print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
 
 
 
 
 
37
 
38
  SYSTEM_PROMPT = textwrap.dedent("""
39
  You are an expert SQL debugger. You will be given a buggy SQL query and must fix it.