ritvik360 committed on
Commit
1a45976
·
verified ·
1 Parent(s): ed2e608

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .env.example +1 -1
  2. client.py +8 -11
  3. inference.py +29 -11
.env.example CHANGED
@@ -10,7 +10,7 @@
10
  API_BASE_URL=https://router.huggingface.co/v1
11
 
12
  # Model identifier — must be accessible at the above endpoint
13
- MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
14
 
15
  # HuggingFace API token (also used as the OpenAI-client api_key)
16
  HF_TOKEN=hf_your_token_here
 
10
  API_BASE_URL=https://router.huggingface.co/v1
11
 
12
  # Model identifier — must be accessible at the above endpoint
13
+ MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
14
 
15
  # HuggingFace API token (also used as the OpenAI-client api_key)
16
  HF_TOKEN=hf_your_token_here
client.py CHANGED
@@ -51,18 +51,15 @@ class NL2SQLEnv:
51
  return self._parse_result(resp.json())
52
 
53
  def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
54
- obs_data = payload.get("observation", {})
55
 
56
- # SAFETY CHECK: Handle JSON 'null' (None) values gracefully
57
- raw_reward = payload.get("reward")
58
- safe_reward = float(raw_reward) if raw_reward is not None else 0.0
59
-
60
- raw_obs_reward = obs_data.get("reward")
61
- safe_obs_reward = float(raw_obs_reward) if raw_obs_reward is not None else 0.0
62
-
63
- raw_score = obs_data.get("score")
64
- safe_score = float(raw_score) if raw_score is not None else 0.0
65
 
 
66
  safe_done = bool(payload.get("done") or obs_data.get("done") or False)
67
 
68
  obs = NL2SQLObservation(
@@ -76,7 +73,7 @@ class NL2SQLEnv:
76
  step=obs_data.get("step", 0),
77
  max_steps=obs_data.get("max_steps", 5),
78
  done=safe_done,
79
- reward=safe_obs_reward,
80
  score=safe_score,
81
  )
82
  return StepResult(
 
51
  return self._parse_result(resp.json())
52
 
53
  def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
54
+ obs_data = payload.get("observation", payload)
55
 
56
+ # ── THE BULLETPROOF REWARD EXTRACTOR ──
57
+ # Check both the top-level payload and the nested observation dict.
58
+ val1 = float(payload.get("reward") or 0.0)
59
+ val2 = float(obs_data.get("reward") or 0.0)
60
+ safe_reward = max(val1, val2)
 
 
 
 
61
 
62
+ safe_score = float(obs_data.get("score") or 0.0)
63
  safe_done = bool(payload.get("done") or obs_data.get("done") or False)
64
 
65
  obs = NL2SQLObservation(
 
73
  step=obs_data.get("step", 0),
74
  max_steps=obs_data.get("max_steps", 5),
75
  done=safe_done,
76
+ reward=safe_reward,
77
  score=safe_score,
78
  )
79
  return StepResult(
inference.py CHANGED
@@ -27,18 +27,36 @@ from typing import List, Optional
27
 
28
  from openai import OpenAI
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # ── Configuration ──────────────────────────────────────────────────────────
31
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
33
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
34
- IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
35
- SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
36
-
37
- BENCHMARK = "nl2sql-bench"
38
- MAX_STEPS = 5
39
- TEMPERATURE = 0.2 # Low temp for SQL generation
40
- MAX_TOKENS = 512
41
- SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
 
 
 
42
 
43
  TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
44
 
 
27
 
28
  from openai import OpenAI
29
 
30
+ # # ── Configuration ──────────────────────────────────────────────────────────
31
+ # API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
+ # MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
33
+ # API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
34
+ # IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
35
+ # SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
36
+
37
+ # BENCHMARK = "nl2sql-bench"
38
+ # MAX_STEPS = 5
39
+ # TEMPERATURE = 0.2 # Low temp for SQL generation
40
+ # MAX_TOKENS = 512
41
+ # SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
42
+
43
+ # TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
44
+
45
  # ── Configuration ──────────────────────────────────────────────────────────
46
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
47
+ # Points to your newly uploaded fine-tuned weights!
48
+ MODEL_NAME = os.getenv("MODEL_NAME", "ritvik360/qwen-7b-nl2sql-merged_1")
49
+ # CRITICAL FIX: Looks for 'API_KEY' first to satisfy the evaluator's LiteLLM proxy
50
+ API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "") or os.getenv("OPENAI_API_KEY")
51
+ IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
52
+ # CRITICAL FIX: Point the default directly to your live HF Space!
53
+ SPACE_URL = os.getenv("SPACE_URL", "https://ritvik360-nl2sql-bench.hf.space")
54
+
55
+ BENCHMARK = "nl2sql-bench"
56
+ MAX_STEPS = 5
57
+ TEMPERATURE = 0.2 # Low temp for SQL generation
58
+ MAX_TOKENS = 512
59
+ SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
60
 
61
  TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
62