kernrl-v2-1-0 / models.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
4f1731d verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the kernrl GPU kernel optimization environment.
The kernrl environment enables training LLMs to write optimized CUDA/Triton
kernels by providing real hardware feedback.
"""
from typing import Optional
from pydantic import Field
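# Support both in-repo and standalone imports.
# NOTE(review): both branches below currently import the identical path
# (openenv.core.env_server.types), so the ImportError fallback is a no-op.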
try:
# In-repo imports (when running from OpenEnv repository)
from openenv.core.env_server.types import Action, Observation, State
except ImportError:
# Standalone imports (when environment is standalone with openenv from pip)
from openenv.core.env_server.types import Action, Observation, State
class KernelAction(Action):
    """A single agent move in kernrl: one submission of kernel source code.

    The ``code`` field is required, and validation rejects an empty string
    (``min_length=1``).
    """

    code: str = Field(
        default=...,
        description="The CUDA/Triton kernel code to evaluate",
        min_length=1,
    )


class KernelObservation(Observation):
    """What the agent receives back from the kernrl environment.

    Bundles the problem context (id, description, PyTorch reference, GPU
    details), turn bookkeeping, free-form feedback text, and the results of
    evaluating the latest submission. ``compilation_success`` defaults to
    ``False``; the remaining evaluation fields default to ``None`` when not
    populated.
    """

    # --- Problem context (required) ---
    problem_id: str = Field(
        default=..., description="Unique identifier for the problem"
    )
    problem_description: str = Field(
        default=..., description="Full problem description with requirements"
    )
    reference_code: str = Field(
        default=..., description="PyTorch reference implementation"
    )
    gpu_info: str = Field(default=..., description="GPU device information")

    # --- Turn bookkeeping ---
    turn: int = Field(0, description="Current turn number", ge=0)
    max_turns: int = Field(10, description="Maximum turns allowed", ge=1)
    feedback: str = Field("", description="Evaluation feedback for the agent")

    # --- Evaluation results for the most recent submission ---
    compilation_success: bool = Field(
        False, description="Whether the code compiled successfully"
    )
    compilation_error: Optional[str] = Field(
        None, description="Compilation error message if failed"
    )
    correctness_pass: Optional[bool] = Field(
        None, description="Whether output matches reference"
    )
    max_diff: Optional[float] = Field(
        None, description="Maximum difference from reference output"
    )
    speedup: Optional[float] = Field(
        None, description="Speedup factor vs PyTorch baseline"
    )


class KernelState(State):
    """Episode-level state for the kernrl environment.

    Records the active problem (``None`` until one is set), the current turn
    and the turn budget, the best speedup seen so far (never negative), and
    whether the problem is marked as solved.
    """

    problem_id: Optional[str] = Field(None, description="Current problem ID")

    # Turn counter and budget.
    turn: int = Field(0, description="Current turn number", ge=0)
    max_turns: int = Field(10, description="Maximum turns allowed", ge=1)

    # Progress tracking.
    best_speedup: float = Field(
        0.0, description="Best speedup achieved so far", ge=0.0
    )
    solved: bool = Field(False, description="Whether the problem has been solved")