Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .env.example +1 -1
- client.py +8 -11
- inference.py +29 -11
.env.example
CHANGED
|
@@ -10,7 +10,7 @@
|
|
| 10 |
API_BASE_URL=https://router.huggingface.co/v1
|
| 11 |
|
| 12 |
# Model identifier β must be accessible at the above endpoint
|
| 13 |
-
MODEL_NAME=Qwen/Qwen2.5-
|
| 14 |
|
| 15 |
# HuggingFace API token (also used as the OpenAI-client api_key)
|
| 16 |
HF_TOKEN=hf_your_token_here
|
|
|
|
| 10 |
API_BASE_URL=https://router.huggingface.co/v1
|
| 11 |
|
| 12 |
# Model identifier β must be accessible at the above endpoint
|
| 13 |
+
MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
|
| 14 |
|
| 15 |
# HuggingFace API token (also used as the OpenAI-client api_key)
|
| 16 |
HF_TOKEN=hf_your_token_here
|
client.py
CHANGED
|
@@ -51,18 +51,15 @@ class NL2SQLEnv:
|
|
| 51 |
return self._parse_result(resp.json())
|
| 52 |
|
| 53 |
def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
|
| 54 |
-
obs_data = payload.get("observation",
|
| 55 |
|
| 56 |
-
#
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
safe_obs_reward = float(raw_obs_reward) if raw_obs_reward is not None else 0.0
|
| 62 |
-
|
| 63 |
-
raw_score = obs_data.get("score")
|
| 64 |
-
safe_score = float(raw_score) if raw_score is not None else 0.0
|
| 65 |
|
|
|
|
| 66 |
safe_done = bool(payload.get("done") or obs_data.get("done") or False)
|
| 67 |
|
| 68 |
obs = NL2SQLObservation(
|
|
@@ -76,7 +73,7 @@ class NL2SQLEnv:
|
|
| 76 |
step=obs_data.get("step", 0),
|
| 77 |
max_steps=obs_data.get("max_steps", 5),
|
| 78 |
done=safe_done,
|
| 79 |
-
reward=
|
| 80 |
score=safe_score,
|
| 81 |
)
|
| 82 |
return StepResult(
|
|
|
|
| 51 |
return self._parse_result(resp.json())
|
| 52 |
|
| 53 |
def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
|
| 54 |
+
obs_data = payload.get("observation", payload)
|
| 55 |
|
| 56 |
+
# ββ THE BULLETPROOF REWARD EXTRACTOR ββ
|
| 57 |
+
# Check both the top-level payload and the nested observation dict.
|
| 58 |
+
val1 = float(payload.get("reward") or 0.0)
|
| 59 |
+
val2 = float(obs_data.get("reward") or 0.0)
|
| 60 |
+
safe_reward = max(val1, val2)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
safe_score = float(obs_data.get("score") or 0.0)
|
| 63 |
safe_done = bool(payload.get("done") or obs_data.get("done") or False)
|
| 64 |
|
| 65 |
obs = NL2SQLObservation(
|
|
|
|
| 73 |
step=obs_data.get("step", 0),
|
| 74 |
max_steps=obs_data.get("max_steps", 5),
|
| 75 |
done=safe_done,
|
| 76 |
+
reward=safe_reward,
|
| 77 |
score=safe_score,
|
| 78 |
)
|
| 79 |
return StepResult(
|
inference.py
CHANGED
|
@@ -27,18 +27,36 @@ from typing import List, Optional
|
|
| 27 |
|
| 28 |
from openai import OpenAI
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
-
API_BASE_URL
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 44 |
|
|
|
|
| 27 |
|
| 28 |
from openai import OpenAI
|
| 29 |
|
| 30 |
+
# # ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
# API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 32 |
+
# MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
|
| 33 |
+
# API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
|
| 34 |
+
# IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 35 |
+
# SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
|
| 36 |
+
|
| 37 |
+
# BENCHMARK = "nl2sql-bench"
|
| 38 |
+
# MAX_STEPS = 5
|
| 39 |
+
# TEMPERATURE = 0.2 # Low temp for SQL generation
|
| 40 |
+
# MAX_TOKENS = 512
|
| 41 |
+
# SUCCESS_THRESHOLD = 0.7 # score >= 0.7 β success
|
| 42 |
+
|
| 43 |
+
# TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 44 |
+
|
| 45 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 47 |
+
# Points to your newly uploaded fine-tuned weights!
|
| 48 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "ritvik360/qwen-7b-nl2sql-merged_1")
|
| 49 |
+
# CRITICAL FIX: Looks for 'API_KEY' first to satisfy the evaluator's LiteLLM proxy
|
| 50 |
+
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "") or os.getenv("OPENAI_API_KEY")
|
| 51 |
+
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 52 |
+
# CRITICAL FIX: Point the default directly to your live HF Space!
|
| 53 |
+
SPACE_URL = os.getenv("SPACE_URL", "https://ritvik360-nl2sql-bench.hf.space")
|
| 54 |
+
|
| 55 |
+
BENCHMARK = "nl2sql-bench"
|
| 56 |
+
MAX_STEPS = 5
|
| 57 |
+
TEMPERATURE = 0.2 # Low temp for SQL generation
|
| 58 |
+
MAX_TOKENS = 512
|
| 59 |
+
SUCCESS_THRESHOLD = 0.7 # score >= 0.7 β success
|
| 60 |
|
| 61 |
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 62 |
|