Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Data models for the kernrl GPU kernel optimization environment. | |
| The kernrl environment enables training LLMs to write optimized CUDA/Triton | |
| kernels by providing real hardware feedback. | |
| """ | |
| from typing import Optional | |
| from pydantic import Field | |
| # Support both in-repo and standalone imports | |
| try: | |
| # In-repo imports (when running from OpenEnv repository) | |
| from openenv.core.env_server.types import Action, Observation, State | |
| except ImportError: | |
| # Standalone imports (when environment is standalone with openenv from pip) | |
| from openenv.core.env_server.types import Action, Observation, State | |
class KernelAction(Action):
    """One agent move in kernrl: a single kernel source submission.

    The ``code`` field holds the CUDA/Triton source the environment should
    evaluate; an empty string is rejected during validation.
    """

    # Kernel source text; ``min_length=1`` forbids an empty submission.
    code: str = Field(
        ...,
        description="The CUDA/Triton kernel code to evaluate",
        min_length=1,
    )
class KernelObservation(Observation):
    """What the kernrl environment reports back to the agent.

    The observation bundles three groups of fields:

    * problem context: the required ID, description, PyTorch reference and
      GPU info;
    * turn bookkeeping: the turn counter, the turn limit and free-form
      feedback text;
    * evaluation results: compilation status and error, correctness, maximum
      difference and speedup. The optional result fields default to ``None``.
    """

    # --- Problem context (required: no defaults) ---
    problem_id: str = Field(..., description="Unique identifier for the problem")
    problem_description: str = Field(
        ...,
        description="Full problem description with requirements",
    )
    reference_code: str = Field(..., description="PyTorch reference implementation")
    gpu_info: str = Field(..., description="GPU device information")

    # --- Turn bookkeeping ---
    turn: int = Field(0, description="Current turn number", ge=0)
    max_turns: int = Field(10, description="Maximum turns allowed", ge=1)
    feedback: str = Field("", description="Evaluation feedback for the agent")

    # --- Evaluation results ---
    compilation_success: bool = Field(
        False,
        description="Whether the code compiled successfully",
    )
    compilation_error: Optional[str] = Field(
        None,
        description="Compilation error message if failed",
    )
    correctness_pass: Optional[bool] = Field(
        None,
        description="Whether output matches reference",
    )
    max_diff: Optional[float] = Field(
        None,
        description="Maximum difference from reference output",
    )
    speedup: Optional[float] = Field(
        None,
        description="Speedup factor vs PyTorch baseline",
    )
class KernelState(State):
    """Episode-level bookkeeping for the kernrl environment.

    This state tracks:

    * which problem is active (``None`` until one is set);
    * the turn counter against its limit;
    * the best speedup recorded so far (non-negative, starting at 0.0);
    * whether the problem counts as solved.
    """

    # Active problem; ``None`` by default.
    problem_id: Optional[str] = Field(None, description="Current problem ID")

    # Turn counter and its upper limit.
    turn: int = Field(0, description="Current turn number", ge=0)
    max_turns: int = Field(10, description="Maximum turns allowed", ge=1)

    # Progress tracking.
    best_speedup: float = Field(
        0.0,
        description="Best speedup achieved so far",
        ge=0.0,
    )
    solved: bool = Field(False, description="Whether the problem has been solved")