# ROCmPort-AI / backend/models.py
# tazwarrrr's picture
# v2
# 984e3c2
from pydantic import BaseModel
from typing import Optional, List
from enum import Enum
class AgentStatus(str, Enum):
    # Closed set of status codes; AgentEvent.status is typed with this enum.
    # The ``str`` mixin makes every member compare equal to its raw value,
    # e.g. AgentStatus.DONE == "done".

    # Not started yet.
    WAITING = "waiting"
    # Currently in progress.
    RUNNING = "running"
    # Finished.
    DONE = "done"
    # Ended with a failure.
    FAILED = "failed"
    # Being attempted again.
    RETRYING = "retrying"
class WorkloadType(str, Enum):
    # Workload classification; AnalyzerResult.workload_type is typed with
    # this enum. Members are plain strings thanks to the ``str`` mixin.

    COMPUTE_BOUND = "compute-bound"
    MEMORY_BOUND = "memory-bound"
    # Used when neither of the two categories above was assigned.
    UNKNOWN = "unknown"
class PortRequest(BaseModel):
    # Input payload for a port request: CUDA source text plus two options.

    # CUDA source text (required).
    cuda_code: str
    # Kernel label; "custom" is used when the caller omits it.
    kernel_name: Optional[str] = "custom"
    # Switch for the "Explain Like I'm 5" feature; off unless requested.
    simple_mode: Optional[bool] = False
class ColdStartRequest(BaseModel):
    # Input payload holding CUDA source text and an optional kernel label.

    # CUDA source text (required).
    cuda_code: str
    # Kernel label; "unknown_input" is used when the caller omits it.
    kernel_name: Optional[str] = "unknown_input"
class AggregateMetricsRequest(BaseModel):
    # Input payload naming the kernels whose metrics should be aggregated.

    # Kernel names to include; None when the caller supplies nothing.
    # NOTE(review): presumably None means "no filter" — confirm with the
    # code that consumes this request.
    kernel_names: Optional[List[str]] = None
class AgentEvent(BaseModel):
    # One status message attributed to a named agent.

    # Agent name: analyzer | translator | optimizer | tester | coordinator
    agent: str
    # Status code, one of the AgentStatus members.
    status: AgentStatus
    # Main message text.
    message: str
    # Optional extra text; None when absent.
    detail: Optional[str] = None
class VerificationResult(BaseModel):
    # Verification outcome: three required pass/fail flags plus optional
    # checksum strings and a mock-mode marker.

    # Required flags, one per verification stage.
    compiled_successfully: bool
    executed_without_error: bool
    output_matches_expected: bool

    # Optional checksum strings; each is None when not provided.
    checksum_computed: Optional[str] = None
    expected_checksum: Optional[str] = None
    actual_checksum: Optional[str] = None

    # Mock-mode marker; False unless set.
    # NOTE(review): presumably True means the result did not come from a
    # real run — confirm with the producer of this model.
    mock_mode: Optional[bool] = False
class CostEstimate(BaseModel):
    # Cost comparison figures; every field is a required string.

    manual_porting_weeks: str
    rocmport_minutes: str
    estimated_savings: str
    # One of: Low | Medium | High
    complexity_factor: str
class RiskItem(BaseModel):
    """One flagged pattern found by the pure-Python static scanner."""

    # Source line number, counted from 1; None if it cannot be determined.
    line: Optional[int] = None
    # Matched text or the name of the pattern.
    pattern: str
    # Severity label: CRITICAL | HIGH | MEDIUM
    risk_level: str
    # Explanation intended for human readers.
    description: str
    # Concrete fix suggestion for AMD wavefront-64.
    amd_fix_hint: str
class StaticRiskReport(BaseModel):
    """Aggregated output of the static wavefront correctness scanner."""

    # Individual findings, each a RiskItem.
    items: List[RiskItem]

    # Per-severity tallies matching the RiskItem.risk_level labels.
    critical_count: int
    high_count: int
    medium_count: int

    # Scan duration in milliseconds, reported for transparency
    # (the original author notes the scan runs in under 5 ms).
    scan_duration_ms: float
class AnalyzerResult(BaseModel):
    # Analyzer output: kernel and API names, warp-size and sharding flags,
    # a workload classification, a difficulty rating, and optional extras.

    # Names of the kernels found.
    kernels_found: List[str]
    # Names of the CUDA APIs found.
    cuda_apis: List[str]
    # Whether a warp-size issue was flagged.
    warp_size_issue: bool
    # Optional text describing the warp-size issue. Given an explicit None
    # default, like every other Optional field in this file, so it may be
    # omitted under both pydantic v1 and v2 (v2 would otherwise require it).
    warp_size_detail: Optional[str] = None
    # Workload classification, one of the WorkloadType members.
    workload_type: WorkloadType
    # Whether sharding was detected.
    sharding_detected: bool
    # Difficulty rating: Easy | Medium | Hard
    difficulty: str
    # Text explaining the difficulty rating.
    difficulty_reason: str
    # Optional prediction text; None when absent.
    prediction: Optional[str] = None
    # Optional line count; None when absent.
    line_count: Optional[int] = None
    # Optional complexity score; None when absent.
    complexity_score: Optional[int] = None
    # Optional static scanner report; None when absent.
    static_risk_report: Optional[StaticRiskReport] = None
class TranslatorResult(BaseModel):
    # Translator output: HIP source text, change counters, and a diff.

    # HIP source text.
    hip_code: str

    # Change counters: overall, from hipify, and from the LLM.
    total_changes: int
    hipify_changes: int
    llm_changes: int

    # Diff entries shaped like [{line, old, new, confidence, source}].
    diff_lines: List[dict]
class OptimizerResult(BaseModel):
    # Optimizer output: optimized source text, a change list, and an
    # iteration number.

    # Optimized source text.
    optimized_code: str
    # Change entries shaped like [{description, impact}].
    changes: List[dict]
    # Iteration number this result belongs to.
    iteration: int
class TesterResult(BaseModel):
    # Tester output: success flag, iteration number, measured figures,
    # free-text findings, and optional verification data.

    success: bool
    iteration: int

    # Speedup relative to baseline HIP.
    speedup: float
    # Bandwidth utilization, as a percentage.
    bandwidth_utilized: float
    # Execution time in milliseconds.
    execution_ms: float

    bottleneck: str
    notes: str

    # Trust layer verification; None when absent.
    verification: Optional[VerificationResult] = None
    # Optional data-source label; None when absent.
    data_source: Optional[str] = None
class FinalReport(BaseModel):
    # Final report: headline figures, both source texts, and optional
    # verification, cost, explanation, and static-risk sections.

    # Required headline fields.
    migration_success: bool
    speedup: float
    bandwidth_utilized: float
    total_changes: int
    bottleneck: str
    amd_advantage_explanation: str
    iterations: int

    # HIP source text and optimized source text.
    hip_code: str
    optimized_code: str

    # Optional verification result; None when absent.
    verification: Optional[VerificationResult] = None
    # Optional cost impact estimate; None when absent.
    cost_estimate: Optional[CostEstimate] = None
    # Text for the "Explain Like I'm 5" mode; None when absent.
    simplified_explanation: Optional[str] = None
    # Static risk data surfaced in the final report; None when absent.
    static_risk_report: Optional[StaticRiskReport] = None
    # Data provenance label: real_rocm | demo_artifact | simulated
    data_source: str = "simulated"