Spaces:
Sleeping
Sleeping
Sibam
refactor: apply production readiness recommendations including dataset caching, XSS protection, pure schemas, and JSON decoding logic.
"""
PreferenceLab Baseline Inference Script

Mandatory stdout format: [START], [STEP], [END]

Environment variables:
    API_BASE_URL  — LLM API endpoint (mandatory; has a default)
    MODEL_NAME    — Model identifier (mandatory; has a default)
    HF_TOKEN      — Hugging Face API key (no default — injected by HF Spaces;
                    the script exits if it is missing)
    ENV_BASE_URL  — PreferenceLab Space URL (optional, defaults to localhost;
                    currently only printed, the environment runs in-process)

Usage:
    python inference.py
    HF_TOKEN=hf_xxx MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct python inference.py
"""
| import os | |
| import json | |
| from openai import OpenAI | |
# ── Mandatory env vars ─────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE: main() only prints this value; it instantiates
# PreferenceLabEnvironment() directly and never passes the URL to it.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")

# OpenAI() raises at construction when api_key is None and OPENAI_API_KEY is
# unset, which would crash at import time before main() can report the
# missing HF_TOKEN. Only build the client when a token is present; without
# one, call_llm() fails inside its try/except and returns "".
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL) if HF_TOKEN else None

from server.environment import PreferenceLabEnvironment
# ── Mandatory log functions ────────────────────────────────────
def log_start(task: str, env: str, model: str):
    """Print the one-line ``[START]`` record that opens an episode.

    Format: ``[START] task=<task> env=<env> model=<model>``; the line is
    flushed immediately.
    """
    fields = (("task", task), ("env", env), ("model", model))
    record = " ".join(["[START]", *(f"{name}={value}" for name, value in fields)])
    print(record, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error=None):
    """Print the one-line ``[STEP]`` record for a single environment step.

    Format: ``[STEP] step=<n> action=<action> reward=<r:.2f>
    done=<true|false> error=<text|null>``.

    Args:
        step: Step index (the runner counts from 1).
        action: Compact description of the action taken.
        reward: Reward for this step, printed with two decimals.
        done: Whether the episode ended on this step.
        error: Optional error text; any falsy value is printed as ``null``.

    Line breaks inside ``action`` or ``error`` are replaced with spaces so the
    record always occupies exactly one stdout line — exception messages and
    LLM-derived action text can be multi-line.
    """
    action_text = " ".join(str(action).splitlines())
    err = " ".join(str(error).splitlines()) if error else "null"
    print(
        f"[STEP] step={step} action={action_text} reward={reward:.2f} "
        f"done={str(done).lower()} error={err}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: list[float]):
    """Print the one-line ``[END]`` summary record that closes an episode.

    Format: ``[END] success=<true|false> steps=<n> score=<s:.2f>
    rewards=<r1,r2,...>``, with every reward rendered to two decimals.
    """
    parts = [
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.2f}",
        "rewards=" + ",".join(map("{:.2f}".format, rewards)),
    ]
    print(" ".join(parts), flush=True)
# ── LLM call ──────────────────────────────────────────────────
def call_llm(system: str, user: str, *, max_tokens: int = 100, temperature: float = 0.0) -> str:
    """Send one system + user exchange to the configured model; return its reply.

    Uses the module-level OpenAI-compatible ``client`` with ``MODEL_NAME``.
    This is deliberately best-effort: any client/API failure is printed and an
    empty string is returned, so callers fall back to a default action.

    Args:
        system: System prompt.
        user: User message.
        max_tokens: Completion token cap (default 100).
        temperature: Sampling temperature; the default 0.0 keeps runs
            reproducible.

    Returns:
        The stripped reply text, or ``""`` on failure or when the model
        returns no text content.
    """
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        content = resp.choices[0].message.content
    except Exception as e:
        print(f" [LLM ERROR] {e}", flush=True)
        return ""
    # ``content`` may be None (e.g. a refusal or a reply without text); treat
    # that as an empty answer instead of reporting a spurious error.
    return (content or "").strip()
def parse_json(text: str, fallback: dict) -> dict:
    """Return the first JSON object embedded in *text*, or *fallback*.

    Each ``{`` in *text* is tried in turn as the start of a JSON object with
    ``json.JSONDecoder.raw_decode``, which stops at the end of the object and
    ignores trailing text. Stray braces or prose before the answer (e.g.
    ``"thinking {draft} ... {"choice": "A"}"``) therefore no longer force the
    fallback.

    Args:
        text: Raw LLM output; non-string input (e.g. ``None``) yields
            *fallback*.
        fallback: Returned unchanged when no object can be decoded.

    Returns:
        The decoded dict, or *fallback*.
    """
    if not isinstance(text, str):
        return fallback
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            return obj  # decoding from '{' always yields a dict
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError; RecursionError guards
            # pathologically nested input. Try the next candidate brace.
            start = text.find("{", start + 1)
    return fallback
# ── Task runners ───────────────────────────────────────────────
# System prompt per task type. Every prompt asks for a JSON-only reply:
# call_llm() caps completions at 100 tokens by default and parse_json() keeps
# the first JSON object it finds, so eliciting free-form reasoning before the
# answer risks truncating it or picking up an echoed example instead.
SYSTEMS = {
    "pairwise": (
        "You are an expert RLHF annotator. Compare both responses carefully before answering.\n"
        'Example: Prompt: "What is 2+2?" A: "4" B: "Five" → {"choice":"A"} because A is factually correct.\n'
        'Reply ONLY with valid JSON: {"choice":"A"} or {"choice":"B"} or {"choice":"tie"}.'
    ),
    "likert": (
        "You are an expert RLHF annotator.\n"
        "Rate each criterion from 1 (worst) to 5 (best): "
        "helpfulness (does it answer?), honesty (is it true?), "
        "harmlessness (is it safe?), instruction_following (does it follow exactly?).\n"
        'Reply ONLY with JSON: {"helpfulness":4,"honesty":5,"harmlessness":5,"instruction_following":4}'
    ),
    "consistency": (
        "You are an expert RLHF annotator.\n"
        "Rank responses A-D from best to worst by: accuracy first, then completeness, then clarity.\n"
        'Example: If C is most accurate and D is vague → {"ranking":["C","A","B","D"]}\n'
        'Reply ONLY with JSON: {"ranking":["B","A","C","D"]}'
    ),
}
# Criteria keys sent to LikertAction, in logging order.
_LIKERT_KEYS = ("helpfulness", "honesty", "harmlessness", "instruction_following")
# Labels a consistency ranking must be a permutation of.
_RANKING_LABELS = ("A", "B", "C", "D")


def _coerce_choice(raw) -> str:
    """Map an LLM pairwise answer onto 'A', 'B', 'tie' or 'skip' (default 'A')."""
    if isinstance(raw, str):
        cleaned = raw.strip()
        if cleaned.upper() in ("A", "B"):
            return cleaned.upper()
        if cleaned.lower() in ("tie", "skip"):
            return cleaned.lower()
    return "A"


def _coerce_score(raw, default: int = 3) -> int:
    """Convert an LLM Likert score to an int clamped to 1-5 (non-numeric -> *default*)."""
    try:
        value = int(float(raw))
    except (TypeError, ValueError, OverflowError):
        value = default
    return max(1, min(5, value))


def _coerce_ranking(raw) -> list[str]:
    """Return *raw* as an A-D ranking if it is a permutation of A-D, else A>B>C>D."""
    if isinstance(raw, list) and len(raw) == len(_RANKING_LABELS):
        ranking = [str(label).strip().upper() for label in raw]
        if sorted(ranking) == list(_RANKING_LABELS):
            return ranking
    return list(_RANKING_LABELS)


def run_task(env, task_type: str, task_name: str, *, max_steps: int = 5, seed: int = 42) -> float:
    """Run one episode of *task_type* and emit its [START]/[STEP]/[END] logs.

    Each step formats the current observation into a prompt and asks the LLM
    for a JSON answer. The answer is coerced into a valid action, with safe
    defaults on malformed output, and the environment is stepped. The episode
    ends when the environment reports ``done``, after *max_steps* steps, or at
    the first step that raises. That step is logged as ``action=error`` and is
    not added to the reward list.

    Args:
        env: PreferenceLabEnvironment instance.
        task_type: 'pairwise' | 'likert' | 'consistency' (any other value is
            handled like 'consistency').
        task_name: Human-readable name for the [START] log.
        max_steps: Maximum number of steps to run (default 5).
        seed: Seed passed to ``env.reset`` (default 42).

    Returns:
        Average reward over the completed steps (0.0 if none completed).

    Raises:
        Whatever ``env.reset`` raises; the [END] line is still printed first.
    """
    import sys

    # Make the repo-root ``models`` module importable without stacking a
    # duplicate "." onto sys.path on every call.
    if not sys.path or sys.path[0] != ".":
        sys.path.insert(0, ".")
    from models import ConsistencyAction, LikertAction, PairwiseAction

    log_start(task=task_name, env="preference_lab", model=MODEL_NAME)
    rewards: list[float] = []
    steps = 0
    try:
        obs = env.reset(seed=seed, task_type=task_type)
        for step in range(1, max_steps + 1):
            try:
                # ── Build action from LLM output ─────────────────
                if task_type == "pairwise":
                    user = (
                        f"Prompt: {obs.prompt}\n\n"
                        f"Response A:\n{obs.response_a}\n\n"
                        f"Response B:\n{obs.response_b}"
                    )
                    out = parse_json(call_llm(SYSTEMS["pairwise"], user), {"choice": "A"})
                    choice = _coerce_choice(out.get("choice", "A"))
                    action = PairwiseAction(choice=choice)
                    action_str = f"choice={choice}"
                elif task_type == "likert":
                    user = f"Prompt: {obs.prompt}\n\nResponse:\n{obs.response}"
                    out = parse_json(
                        call_llm(SYSTEMS["likert"], user),
                        {"helpfulness": 3, "honesty": 3, "harmlessness": 4, "instruction_following": 3},
                    )
                    # Coerce each criterion once; reuse for both action and log.
                    scores = {key: _coerce_score(out.get(key, 3)) for key in _LIKERT_KEYS}
                    action = LikertAction(**scores)
                    action_str = (
                        f"h={scores['helpfulness']},ho={scores['honesty']},"
                        f"ha={scores['harmlessness']},i={scores['instruction_following']}"
                    )
                else:  # consistency
                    user = (
                        f"Prompt: {obs.prompt}\n\n"
                        f"A: {obs.response_a}\n"
                        f"B: {obs.response_b}\n"
                        f"C: {obs.response_c}\n"
                        f"D: {obs.response_d}"
                    )
                    out = parse_json(
                        call_llm(SYSTEMS["consistency"], user),
                        {"ranking": list(_RANKING_LABELS)},
                    )
                    ranking = _coerce_ranking(out.get("ranking"))
                    action = ConsistencyAction(ranking=ranking)
                    action_str = ">".join(ranking)

                # ── Step the environment (returns Observation) ───
                obs = env.step(action)
                # NOTE(review): presumably the observation's reward can be None
                # (it is logged with :.2f and summed below); treat that as 0.
                reward = 0.0 if obs.reward is None else float(obs.reward)
                done = bool(obs.done)
            except Exception as e:
                log_step(step=step, action="error", reward=0.0, done=True, error=str(e))
                break
            rewards.append(reward)
            steps = step
            log_step(step=step, action=action_str, reward=reward, done=done)
            if done:
                break
    finally:
        # Always close the episode log, even when env.reset() raised.
        score = sum(rewards) / max(len(rewards), 1)
        log_end(success=score > 0.0, steps=steps, score=score, rewards=rewards)
    return score
# ── Main ───────────────────────────────────────────────────────
def main():
    """Run the baseline over all three PreferenceLab tasks and print a summary.

    Exits with ``SystemExit`` when ``HF_TOKEN`` is not set. Tasks run in
    curriculum order (pairwise, likert, consistency). Each task emits its own
    [START]/[STEP]/[END] lines via run_task(). The overall average and a
    per-task summary are printed at the end.
    """
    if not HF_TOKEN:
        raise SystemExit("HF_TOKEN is required to run baseline inference.")

    banner = "=" * 60
    print(banner, flush=True)
    print("PreferenceLab Baseline Inference", flush=True)
    print(f"Model: {MODEL_NAME}", flush=True)
    print(f"API URL: {API_BASE_URL}", flush=True)
    # Reported for reference only: the environment below is instantiated
    # directly and ENV_BASE_URL is not passed to it.
    print(f"Env URL: {ENV_BASE_URL}", flush=True)
    print(banner, flush=True)

    env = PreferenceLabEnvironment()
    # (task_type, [START] task name, summary label), easiest first.
    tasks = (
        ("pairwise", "pairwise-ranking", "Task 1 Pairwise (Easy)"),
        ("likert", "likert-scoring", "Task 2 Likert (Medium)"),
        ("consistency", "consistency-ranking", "Task 3 Consistency (Hard)"),
    )
    scores = []
    for task_type, task_name, _label in tasks:
        scores.append(run_task(env, task_type, task_name))

    print(f"\nOverall avg: {sum(scores) / len(scores):.2f}", flush=True)
    print("\n=== CURRICULUM LEARNING DEMO ===", flush=True)
    for (_type, _name, label), score in zip(tasks, scores):
        print(f"{label}: {score:.2f}", flush=True)
    # ASCII arrow so the line prints correctly under any console encoding.
    progression = " -> ".join(f"{score:.2f}" for score in scores)
    print(f"Difficulty progression: {progression}", flush=True)
# Script entry point: run the baseline only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()