mindchain commited on
Commit
47ab43e
·
verified ·
1 Parent(s): 410f2c5

Update models.py for OpenRouter models

Browse files
Files changed (1) hide show
  1. models.py +9 -6
models.py CHANGED
@@ -3,10 +3,13 @@ from typing import Optional
3
  from enum import Enum
4
 
5
 
6
- class ZaiModel(str, Enum):
7
- GLM_5 = "glm-5"
8
- GLM_47 = "glm-4.7"
9
- GLM_45_AIR = "glm-4.5-air"
 
 
 
10
 
11
 
12
  class SamplerType(str, Enum):
@@ -25,14 +28,14 @@ class ColumnConfig(BaseModel):
25
 
26
  class GenerateRequest(BaseModel):
27
  num_records: int = Field(default=10, ge=1, le=1000, description="Number of records to generate")
28
- model: ZaiModel = Field(default=ZaiModel.GLM_47, description="z.ai model to use")
29
  columns: list[ColumnConfig] = Field(..., description="Column configurations")
30
  temperature: float = Field(default=0.7, ge=0.0, le=2.0)
31
  max_tokens: int = Field(default=512, ge=64, le=4096)
32
 
33
 
34
  class PreviewRequest(BaseModel):
35
- model: ZaiModel = Field(default=ZaiModel.GLM_47)
36
  columns: list[ColumnConfig] = Field(..., description="Column configurations")
37
  temperature: float = Field(default=0.7)
38
  max_tokens: int = Field(default=512)
 
3
  from enum import Enum
4
 
5
 
6
+ class OpenRouterModel(str, Enum):
7
+ CLAUDE_35_SONNET = "claude-3.5-sonnet"
8
+ CLAUDE_3_OPUS = "claude-3-opus"
9
+ GPT_4O = "gpt-4o"
10
+ GPT_4O_MINI = "gpt-4o-mini"
11
+ LLAMA_31_70B = "llama-3.1-70b"
12
+ LLAMA_31_8B = "llama-3.1-8b"
13
 
14
 
15
  class SamplerType(str, Enum):
 
28
 
29
  class GenerateRequest(BaseModel):
30
  num_records: int = Field(default=10, ge=1, le=1000, description="Number of records to generate")
31
+ model: OpenRouterModel = Field(default=OpenRouterModel.GPT_4O_MINI, description="OpenRouter model to use")
32
  columns: list[ColumnConfig] = Field(..., description="Column configurations")
33
  temperature: float = Field(default=0.7, ge=0.0, le=2.0)
34
  max_tokens: int = Field(default=512, ge=64, le=4096)
35
 
36
 
37
  class PreviewRequest(BaseModel):
38
+ model: OpenRouterModel = Field(default=OpenRouterModel.GPT_4O_MINI)
39
  columns: list[ColumnConfig] = Field(..., description="Column configurations")
40
  temperature: float = Field(default=0.7)
41
  max_tokens: int = Field(default=512)