Spaces:
Sleeping
Sleeping
Add arcee-ai/prime model
Browse files
models.py
CHANGED
|
@@ -4,39 +4,32 @@ from enum import Enum
|
|
| 4 |
|
| 5 |
|
| 6 |
class OpenRouterModel(str, Enum):
|
| 7 |
-
|
| 8 |
-
CLAUDE_3_OPUS = "claude-3-opus"
|
| 9 |
-
GPT_4O = "gpt-4o"
|
| 10 |
-
GPT_4O_MINI = "gpt-4o-mini"
|
| 11 |
-
LLAMA_31_70B = "llama-3.1-70b"
|
| 12 |
-
LLAMA_31_8B = "llama-3.1-8b"
|
| 13 |
|
| 14 |
|
| 15 |
class SamplerType(str, Enum):
|
| 16 |
CATEGORY = "CATEGORY"
|
| 17 |
UNIFORM = "UNIFORM"
|
| 18 |
GAUSSIAN = "GAUSSIAN"
|
| 19 |
-
UUID = "UUID"
|
| 20 |
-
DATETIME = "DATETIME"
|
| 21 |
|
| 22 |
|
| 23 |
class ColumnConfig(BaseModel):
|
| 24 |
name: str = Field(..., description="Column name")
|
| 25 |
-
type: str = Field(..., description="Column type: sampler, llm_text
|
| 26 |
-
params: dict = Field(default_factory=dict
|
| 27 |
|
| 28 |
|
| 29 |
class GenerateRequest(BaseModel):
|
| 30 |
-
num_records: int = Field(default=10, ge=1, le=
|
| 31 |
-
model: OpenRouterModel = Field(default=OpenRouterModel.
|
| 32 |
columns: list[ColumnConfig] = Field(..., description="Column configurations")
|
| 33 |
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
|
| 34 |
max_tokens: int = Field(default=512, ge=64, le=4096)
|
| 35 |
|
| 36 |
|
| 37 |
class PreviewRequest(BaseModel):
|
| 38 |
-
model: OpenRouterModel = Field(default=OpenRouterModel.
|
| 39 |
-
columns: list[ColumnConfig] = Field(...
|
| 40 |
temperature: float = Field(default=0.7)
|
| 41 |
max_tokens: int = Field(default=512)
|
| 42 |
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class OpenRouterModel(str, Enum):
|
| 7 |
+
ARCEE_PRIME = "arcee-ai/prime"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class SamplerType(str, Enum):
|
| 11 |
CATEGORY = "CATEGORY"
|
| 12 |
UNIFORM = "UNIFORM"
|
| 13 |
GAUSSIAN = "GAUSSIAN"
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class ColumnConfig(BaseModel):
|
| 17 |
name: str = Field(..., description="Column name")
|
| 18 |
+
type: str = Field(..., description="Column type: sampler, llm_text")
|
| 19 |
+
params: dict = Field(default_factory=dict)
|
| 20 |
|
| 21 |
|
| 22 |
class GenerateRequest(BaseModel):
|
| 23 |
+
num_records: int = Field(default=10, ge=1, le=100)
|
| 24 |
+
model: OpenRouterModel = Field(default=OpenRouterModel.ARCEE_PRIME)
|
| 25 |
columns: list[ColumnConfig] = Field(..., description="Column configurations")
|
| 26 |
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
|
| 27 |
max_tokens: int = Field(default=512, ge=64, le=4096)
|
| 28 |
|
| 29 |
|
| 30 |
class PreviewRequest(BaseModel):
|
| 31 |
+
model: OpenRouterModel = Field(default=OpenRouterModel.ARCEE_PRIME)
|
| 32 |
+
columns: list[ColumnConfig] = Field(...)
|
| 33 |
temperature: float = Field(default=0.7)
|
| 34 |
max_tokens: int = Field(default=512)
|
| 35 |
|