khushiii02 committed on
Commit
ae6d930
·
verified ·
1 Parent(s): 28657e4

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +26 -9
models.py CHANGED
@@ -1,11 +1,11 @@
1
  """
2
  models.py — Typed Pydantic models for the Support Ticket Agent OpenEnv environment.
3
  Satisfies OpenEnv spec: typed Observation, Action, Reward models.
4
- PHASE 2 FIX: All score fields use gt=0.0, lt=1.0 (strictly between 0 and 1).
5
  """
6
  from __future__ import annotations
7
  from typing import Any, Dict, List, Optional
8
- from pydantic import BaseModel, Field
9
 
10
  VALID_DEPARTMENTS: List[str] = [
11
  "Technical", "Billing", "Product", "IT", "Returns", "Sales", "HR"
@@ -26,20 +26,37 @@ class TicketObservation(BaseModel):
26
 
27
  class TicketAction(BaseModel):
28
  department: str = Field(..., description="One of the 7 valid departments")
29
- priority: int = Field(2, ge=1, le=3, description="1=Low 2=Medium 3=High")
30
  reply: Optional[str] = Field("", description="Draft first reply (Task 3 only)")
31
 
32
 
33
  class TicketReward(BaseModel):
34
- # CRITICAL: gt/lt (not ge/le) — strictly between 0 and 1
35
- score: float = Field(..., gt=0.0, lt=1.0)
36
- department_score: float = Field(..., gt=0.0, lt=1.0)
37
- priority_score: float = Field(..., gt=0.0, lt=1.0)
38
- reply_score: float = Field(..., gt=0.0, lt=1.0)
 
 
 
39
  feedback: str
40
  done: bool
41
  correct_department: Optional[str] = None
42
- correct_priority: Optional[int] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  class EnvState(BaseModel):
 
1
  """
2
  models.py — Typed Pydantic models for the Support Ticket Agent OpenEnv environment.
3
  Satisfies OpenEnv spec: typed Observation, Action, Reward models.
4
+ PHASE 2 FIX: Out-of-range scores are clamped to 0.01 / 0.99 by a field validator instead of being rejected by Field constraints.
5
  """
6
  from __future__ import annotations
7
  from typing import Any, Dict, List, Optional
8
+ from pydantic import BaseModel, Field, field_validator
9
 
10
  VALID_DEPARTMENTS: List[str] = [
11
  "Technical", "Billing", "Product", "IT", "Returns", "Sales", "HR"
 
26
 
27
  class TicketAction(BaseModel):
28
  department: str = Field(..., description="One of the 7 valid departments")
29
+ priority: int = Field(2, ge=1, le=3, description="1=Low 2=Medium 3=High")
30
  reply: Optional[str] = Field("", description="Draft first reply (Task 3 only)")
31
 
32
 
33
  class TicketReward(BaseModel):
34
+ """
35
+ Reward model with scores strictly between 0 and 1.
36
+ Using validators to ensure compliance.
37
+ """
38
+ score: float = Field(..., description="Overall score strictly between 0 and 1")
39
+ department_score: float = Field(..., description="Department classification score")
40
+ priority_score: float = Field(..., description="Priority classification score")
41
+ reply_score: float = Field(..., description="Reply quality score")
42
  feedback: str
43
  done: bool
44
  correct_department: Optional[str] = None
45
+ correct_priority: Optional[int] = None
46
+
47
+ @field_validator('score', 'department_score', 'priority_score', 'reply_score', mode='before')
48
+ @classmethod
49
+ def clamp_score(cls, v):
50
+ """Ensure all scores are strictly between 0 and 1."""
51
+ if v is None:
52
+ return 0.5 # Default neutral score
53
+ v = float(v)
54
+ # Clamp to strictly within (0, 1)
55
+ if v <= 0.0:
56
+ return 0.01
57
+ if v >= 1.0:
58
+ return 0.99
59
+ return round(v, 4)
60
 
61
 
62
  class EnvState(BaseModel):