# rust_coder / inference.py
# Hugging Face Hub artifact (uploaded via huggingface_hub, revision 0b15484).
import os
import re
import json
import asyncio
import logging
from typing import List, Optional
from openai import OpenAI
from dotenv import load_dotenv
# Load variables from a local .env file (tokens, endpoints) before any
# os.getenv() lookups below.
load_dotenv()
# --- Logging (inference.py) ---
# Log level comes from LOG_LEVEL; unrecognized values fall back to INFO.
_LOG_LEVEL = (os.getenv("LOG_LEVEL") or "INFO").upper()
logging.basicConfig(
    level=getattr(logging, _LOG_LEVEL, logging.INFO),
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
logger = logging.getLogger("rust_coder.inference")
# --- Competition Configuration ---
# OpenAI-compatible router endpoint and model; overridable via environment.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# HF_TOKEN is preferred; API_KEY is accepted as a fallback alias.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Base URL of the RustCoder evaluation environment service.
ENV_URL = os.getenv("ENV_URL") or "http://localhost:8000"
# Episode constants: 10 problems, each worth max reward 1.0
MAX_STEPS = 10
MAX_TOTAL_REWARD = 10.0
# Normalized score at or above this threshold counts the run as a success.
SUCCESS_SCORE_THRESHOLD = 0.5
# Import client (ensure rust_coder is in PYTHONPATH)
from client import RustCoderEnv
from models import RustCoderAction
# --- Strict Logging Helpers ---
def log_start(task: str, env: str, model: str):
    """Print the strict [START] marker line consumed by the harness."""
    header = f'[START] task="{task}" env="{env}" model="{model}"'
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None):
    """Print the strict [STEP] marker line consumed by the harness.

    Newlines in *action* are escaped so the record stays on one line, and the
    action preview is capped at 100 characters. The trailing "..." is added
    only when truncation actually happened — the original appended it
    unconditionally, falsely marking short actions as truncated.

    Args:
        step: 1-based step index.
        action: The submitted action text (e.g. Rust source code).
        reward: Reward returned by the environment for this step.
        done: Whether the episode finished at this step.
        error: Optional error message to append to the record.
    """
    escaped = action.replace('\n', '\\n')
    preview = escaped[:100] + "..." if len(escaped) > 100 else escaped
    log_line = f'[STEP] step={step} action="{preview}" reward={reward:.4f} done={str(done).lower()}'
    if error:
        log_line += f' error="{error}"'
    print(log_line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]):
    """Print the strict [END] summary line consumed by the harness."""
    fields = [
        f"[END] success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.4f}",
        f"rewards={json.dumps(rewards)}",
    ]
    print(" ".join(fields), flush=True)
# --- LLM Solution Logic ---
async def get_model_code(prompt: str, client: OpenAI) -> str:
    """Call the LLM to get a Rust solution.

    Args:
        prompt: Problem description (optionally with starter code appended).
        client: Configured OpenAI-compatible chat client.

    Returns:
        The extracted Rust source, or a ``// Error: ...`` comment string on
        any failure so the caller can still submit something to the env.
    """
    try:
        logger.info(
            "LLM call start model=%s base_url=%s prompt_chars=%d token_present=%s",
            MODEL_NAME,
            API_BASE_URL,
            len(prompt or ""),
            bool(HF_TOKEN),
        )
        # The OpenAI SDK call is synchronous; run it in a worker thread so
        # the blocking HTTP request does not stall the asyncio event loop
        # (the original called it directly inside this async function).
        completion = await asyncio.to_thread(
            client.chat.completions.create,
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are a senior Rust systems engineer. Return ONLY the complete, fixed Rust code. No explanation."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,  # near-deterministic output for reproducible runs
        )
        text = (completion.choices[0].message.content or "").strip()
        logger.debug("LLM raw response chars=%d", len(text))
        # Extract code from markdown, preferring an explicit ```rust fence.
        if "```rust" in text:
            text = text.split("```rust")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]
        text = text.strip()
        if not text:
            logger.warning("LLM returned empty code after cleanup.")
            return "// Error: empty response (no code returned)."
        logger.info("LLM call end: returned_code_chars=%d", len(text))
        return text
    except Exception as e:
        # Never raise: the episode should continue with an error marker.
        logger.exception("LLM Request failed.")
        return f"// Error: {e}"
# --- Main Evaluation Loop ---
async def main():
    """Run one evaluation episode (up to MAX_STEPS problems).

    Emits the strict [START]/[STEP]/[END] log lines that the competition
    harness parses; their exact format must not change.
    """
    if not HF_TOKEN:
        logger.error("HF_TOKEN/API_KEY not found in environment.")
        return
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    env = RustCoderEnv(base_url=ENV_URL)
    log_start(task="rust_coder", env="RustCoder-v1", model=MODEL_NAME)
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    try:
        # Start the single episode (10 problems)
        result = await env.reset()
        obs = result.observation
        for step in range(1, MAX_STEPS + 1):
            # Stop if the environment already reported completion.
            if result.done:
                break
            steps_taken = step
            # Format prompt including starter code if available
            prompt = obs.problem_description
            if obs.starter_code:
                prompt += f"\n\nStarter Code:\n```rust\n{obs.starter_code}\n```"
            # 1. Ask model for solution to current task
            code_solution = await get_model_code(prompt, client)
            # 2. Environment step
            logger.debug("Submitting to env.step code_chars=%d", len(code_solution or ""))
            result = await env.step(RustCoderAction(code=code_solution))
            obs = result.observation
            # A missing reward is treated as 0.0 for scoring.
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            log_step(step=step, action=code_solution, reward=reward, done=done)
            if done:
                break
        # Normalize score to [0, 1] matching sample format
        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as e:
        logger.exception("Runtime error.")
        # Emit a final [STEP] record carrying the error so the harness sees it.
        log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(e))
    finally:
        # Always try to release the environment and emit the [END] line,
        # even after an exception; close() failures are logged, not raised.
        try:
            await env.close()
        except Exception as e:
            logger.exception("env.close() error.")
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
if __name__ == "__main__":
asyncio.run(main())