# rust_coder/server/app.py
"""
FastAPI application for the Rust Coder Environment.
Endpoints:
POST /reset — Start new episode (loads next problem)
POST /step — Submit Rust code for evaluation
GET /state — Get current episode state
GET /schema — Action/observation JSON schemas
WS /ws — WebSocket for persistent sessions
"""
import os
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv
from openenv.core.env_server.http_server import create_app
from models import RustCoderAction, RustCoderObservation
from server.rust_coder_environment import RustCoderEnvironment
load_dotenv()
# --- Core OpenEnv Server Setup ---
# Use a distinct name for the OpenEnv FastAPI instance; the final `app`
# (see bottom of file) is this instance with the Gradio UI mounted on top.
openenv_app = create_app(
    RustCoderEnvironment,
    RustCoderAction,
    RustCoderObservation,
    env_name="rust_coder",
    # Single-environment server: episodes are evaluated one at a time.
    max_concurrent_envs=1,
)
# Docker healthcheck endpoint registered directly on the base FastAPI app.
@openenv_app.get("/health")
async def health_check():
    """Report liveness for container orchestration healthchecks."""
    return dict(status="healthy")
# --- Shared Logic ---
# LLM endpoint configuration; defaults target the Hugging Face inference
# router and can be overridden via environment variables / .env file.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# Credential for the OpenAI-compatible client: HF_TOKEN preferred,
# API_KEY accepted as a fallback. May be None if neither is set.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
def get_llm_solution(problem_desc: str):
    """Ask the configured LLM for a Rust solution to *problem_desc*.

    Returns the extracted Rust source on success. On any failure the
    return value starts with the sentinel prefix "// LLM Error" — callers
    (evaluate_single, run_benchmark) key off that prefix, so this function
    never raises.
    """
    try:
        if not HF_TOKEN:
            # Fail fast with the sentinel prefix callers already check,
            # instead of letting the OpenAI client raise about a missing key.
            return "// LLM Error: HF_TOKEN / API_KEY is not set"
        client_llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
        completion = client_llm.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are an expert Rust developer. Respond ONLY with the code solution, no explanation."},
                {"role": "user", "content": f"Fix the following Rust problem:\n{problem_desc}"},
            ],
            temperature=0.2,  # low temperature: favour deterministic, compilable code
        )
        text = (completion.choices[0].message.content or "").strip()
        return _strip_code_fences(text)
    except Exception as e:
        return f"// LLM Error: {e}"


def _strip_code_fences(text: str) -> str:
    """Return the first Markdown-fenced code block in *text*, else *text*.

    Prefers a ```rust fence; falls back to a plain ``` fence; leaves
    unfenced responses untouched.
    """
    if "```rust" in text:
        text = text.split("```rust")[1].split("```")[0]
    elif "```" in text:
        text = text.split("```")[1].split("```")[0]
    return text.strip()
def evaluate_single(problem_id, code=None):
    """Evaluate a single problem, generating a solution via the LLM if needed.

    Args:
        problem_id: Dropdown label of the form "<1-based index>: <title> ...".
        code: Optional Rust source to evaluate directly; when falsy, the LLM
            is asked to produce a solution from the problem description.

    Returns:
        (solution_code, metrics) tuple. On any failure solution_code is a
        "// Error ..."/"// LLM Error ..." string and metrics is
        {"error": ...}, so the Gradio handler never raises.
    """
    try:
        # Dropdown labels start with a 1-based problem number.
        idx = int(problem_id.split(":")[0]) - 1
        # One environment instance serves both problem lookup and evaluation
        # (previously a throwaway instance was built just to read `problems`).
        env = RustCoderEnvironment()
        problem = env.problems[idx]
        # 1. Get code from the LLM if none was supplied.
        solution_code = code if code else get_llm_solution(problem["description"])
        # 2. Guard: if the LLM failed there is nothing meaningful to compile.
        if solution_code.startswith("// LLM Error"):
            return solution_code, {"error": "LLM failed to generate a solution. Check your HF_TOKEN."}
        # 3. Reset to the specifically requested problem, then submit the code.
        env.reset(start_index=idx)
        state = env.step(RustCoderAction(code=solution_code))
        metrics = {
            "Total Reward": f"{state.reward:.2f}",
            "Compilation": "Success" if state.compilation_success else "Failed",
            "Metrics": state.reward_breakdown,
        }
        return solution_code, metrics
    except Exception as e:
        # Surface any failure in the UI instead of crashing the handler.
        return f"// Error: {e}", {"error": f"Evaluation system error: {e}"}
def run_benchmark(progress=gr.Progress()):
    """Run every problem through the LLM and return (summary_md, rows).

    NOTE: ``gr.Progress()`` as a default is the documented Gradio idiom for
    injecting a progress tracker into an event handler — it is not a
    mutable-default bug.

    Returns:
        A Markdown summary string and a list of table rows
        [id, title, difficulty, reward, compiled-status]; on error the
        summary carries the message and the rows list is empty.
    """
    try:
        # Fail fast if no credentials are configured — before doing any work.
        if not (os.getenv("HF_TOKEN") or os.getenv("API_KEY")):
            return "## Error: HF_TOKEN is not set. Add it to your HF Space secrets or local .env file.", []
        env = RustCoderEnvironment()
        if not env.problems:
            # Guard: an empty problem set would divide by zero below.
            return "## Benchmark Summary\n**No problems available.**", []
        rows = []
        total_score = 0.0
        for i, problem in enumerate(env.problems):
            progress(i / len(env.problems), desc=f"Benchmarking Task {i+1}...")
            code = get_llm_solution(problem["description"])
            reward = 0.0
            compiled = "Failed (LLM Error)"
            # Only evaluate code the LLM actually produced.
            if not code.startswith("// LLM Error"):
                env.reset(start_index=i)
                state = env.step(RustCoderAction(code=code))
                reward = state.reward
                compiled = "Success" if state.compilation_success else "Failed"
            rows.append([problem["id"], problem["title"], problem.get("difficulty", "N/A"), f"{reward:.2f}", compiled])
            total_score += reward
        avg_score = total_score / len(env.problems)
        summary_md = f"## Benchmark Summary\n**Final Environment Score: {avg_score:.2f} / 1.0**"
        return summary_md, rows
    except Exception as e:
        return f"### Benchmark Error: {e}", []
# --- Build the Gradio UI ---
def create_dashboard():
    """Assemble the Gradio Blocks UI: per-task evaluation + full benchmark."""
    with gr.Blocks(title="Rust Coder Evaluation Dashboard") as demo:
        gr.Markdown("# 🦀 Rust Coder: LLM Evaluation Dashboard")

        with gr.Tab("Individual Task Evaluation"):
            with gr.Row():
                with gr.Column(scale=1):
                    env = RustCoderEnvironment()
                    choices = [
                        f"{p['id']}: {p['title']} ({p.get('difficulty', 'N/A')})"
                        for p in env.problems
                    ]
                    question_dd = gr.Dropdown(choices=choices, label="Select Question", value=choices[0])
                    first = env.problems[0]
                    question_md = gr.Markdown(
                        value=f"### Question [{first.get('difficulty', 'N/A')}]\n{first['description']}"
                    )
                with gr.Column(scale=1):
                    generate_btn = gr.Button("Generate Solution & Evaluate", variant="primary")
                    solution_view = gr.Code(label="AI Generated Solution", interactive=False)
                    metrics_view = gr.JSON(label="Metric Breakdown")

            def show_problem(selection):
                # Dropdown labels are "<1-based index>: ..."; clear the stale
                # solution whenever a different question is selected.
                chosen = env.problems[int(selection.split(":")[0]) - 1]
                desc_md = f"### Question [{chosen.get('difficulty', 'N/A')}]\n{chosen['description']}"
                return desc_md, ""

            question_dd.change(show_problem, inputs=[question_dd], outputs=[question_md, solution_view])
            generate_btn.click(evaluate_single, inputs=[question_dd], outputs=[solution_view, metrics_view])

        with gr.Tab("Full Environment Benchmark"):
            gr.Markdown("### Complete Environment Suite")
            gr.Markdown("Runs the LLM against all 10 tasks sequentially to determine the global OpenEnv score.")
            bench_btn = gr.Button("Run Performance Benchmark", variant="stop")
            bench_summary = gr.Markdown()
            bench_table = gr.Dataframe(
                headers=["ID", "Title", "Difficulty", "Reward", "Compiled"],
                label="Task Results",
            )
            bench_btn.click(run_benchmark, outputs=[bench_summary, bench_table])
    return demo
# Final consolidated Gradio App mounted on the FastAPI server:
# uvicorn serves `app`; the dashboard lives at "/" and the OpenEnv
# endpoints (/reset, /step, /state, /schema, /ws, /health) sit alongside.
app = gr.mount_gradio_app(openenv_app, create_dashboard(), path="/")
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the combined OpenEnv + Gradio app with uvicorn.

    Entry point for ``uv run server`` or ``python -m server.app``.
    """
    # Imported lazily so merely importing this module never pulls in uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()