// Commit f160ece by XciD (HF Staff):
// "fix: support adaptive thinking type in request deserialization"
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::value::RawValue;
// Constants

/// HTTP header name `x-api-key`; presumably the header used to pass the API key.
pub const HEADER_API_KEY: &str = "x-api-key";
// Error types

/// Top-level error envelope: a `type` string plus the nested [`InnerError`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Error {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Nested error details.
    pub error: InnerError,
}
/// Error details nested inside [`Error`]: a `type` string and a message.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InnerError {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
/// Error payload carried by the `error` stream event: a `type` string and a message.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamError {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
// Request types

/// Body of a message-generation request.
///
/// Empty collections and `None` options are left out of the serialized JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GenerateMessageRequest {
    /// System prompt (string or content-block array on input); omitted when empty.
    #[serde(default)]
    #[serde(skip_serializing_if = "MessageContents::is_empty")]
    pub system: MessageContents,
    pub model: String,
    pub messages: Vec<Message>,
    pub max_tokens: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<Thinking>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Sampling temperature: `0.0` when absent, omitted from output when `0.0`.
    ///
    /// NOTE(review): an explicit `0.0` from the client is indistinguishable from
    /// "unset" and is dropped on re-serialization; preserving it would need
    /// `Option<f64>` (an interface change) — confirm whether that matters.
    #[serde(default)]
    #[serde(skip_serializing_if = "is_zero_f64")]
    pub temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Defaults to `true` when the field is absent; always serialized.
    #[serde(default = "default_true")]
    pub stream: bool,
}
/// `skip_serializing_if` predicate: true when the value compares equal to zero.
/// (`-0.0` counts as zero; `NaN` never does.)
fn is_zero_f64(v: &f64) -> bool {
    v.eq(&0.0)
}

/// `serde(default = ...)` provider that always yields `true`.
fn default_true() -> bool {
    true
}
/// Body of a token-counting request; empty/`None` fields are omitted on output.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CountTokensRequest {
    /// System prompt; omitted when empty.
    #[serde(default)]
    #[serde(skip_serializing_if = "MessageContents::is_empty")]
    pub system: MessageContents,
    pub model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<Thinking>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
}
/// Response to a token-counting request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CountTokensResponse {
    /// Number of input tokens reported.
    pub input_tokens: i64,
}
// Message types

/// A conversation message, used both as request input and as the payload of a
/// `message_start` event.
///
/// Only `role` and `content` are required when deserializing. Every other field
/// defaults to `None` and is omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Object tag; wire name `type`.
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub message_type: Option<MessageType>,
    pub role: MessageRole,
    /// Accepts a string or an array of content blocks; always serialized as an array.
    pub content: MessageContents,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    /// Token usage. Omitted when `None`, like the other optional fields, so that
    /// request messages forwarded upstream do not carry a stray `"usage": null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// Object tag of a [`Message`]; its only value is `"message"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
    /// Wire value `"message"`.
    Message,
}
/// Author of a message; wire values `"user"` / `"assistant"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    User,
    Assistant,
}
/// Why generation stopped; serialized in snake_case (e.g. `EndTurn` → `"end_turn"`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    StopSequence,
    ToolUse,
    PauseTurn,
    Refusal,
}
// MessageContents - can be string or array

/// Ordered list of content blocks.
///
/// On input it accepts `null`, a bare string, or an array; on output it is always
/// an array (both via hand-written serde impls).
#[derive(Clone, Debug, Default)]
pub struct MessageContents(pub Vec<MessageContent>);
impl MessageContents {
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn iter(&self) -> impl Iterator<Item = &MessageContent> {
self.0.iter()
}
}
impl Serialize for MessageContents {
    /// Emits the blocks as a JSON sequence. A string input is therefore re-emitted
    /// as a one-element array.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.collect_seq(&self.0)
    }
}
impl<'de> Deserialize<'de> for MessageContents {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error;
let value = serde_json::Value::deserialize(deserializer)?;
match &value {
serde_json::Value::Null => Ok(MessageContents(vec![])),
serde_json::Value::String(s) => Ok(MessageContents(vec![MessageContent {
content_type: MessageContentType::Text,
text: Some(s.clone()),
..Default::default()
}])),
serde_json::Value::Array(_) => {
let arr: Vec<MessageContent> = serde_json::from_value(value.clone())
.map_err(|e| {
eprintln!("Failed to deserialize MessageContents array: {}", e);
eprintln!("Value: {}", serde_json::to_string_pretty(&value).unwrap_or_default());
D::Error::custom(e.to_string())
})?;
Ok(MessageContents(arr))
}
_ => {
eprintln!("Unexpected MessageContents type: {}", serde_json::to_string_pretty(&value).unwrap_or_default());
Err(D::Error::custom("content must be null, string, or array"))
}
}
}
}
impl IntoIterator for MessageContents {
type Item = MessageContent;
type IntoIter = std::vec::IntoIter<MessageContent>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a> IntoIterator for &'a MessageContents {
type Item = &'a MessageContent;
type IntoIter = std::slice::Iter<'a, MessageContent>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
/// Value of a content block's `type` field, serialized in snake_case.
/// [`MessageContentType::Text`] is the `Default`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageContentType {
    #[default]
    Text,
    Image,
    ToolUse,
    ToolResult,
    Thinking,
    RedactedThinking,
    ServerToolUse,
    WebSearchToolResult,
    WebSearchResult,
}
/// A single content block inside a [`MessageContents`] list.
///
/// Only `type` is required on input. Every other field defaults to `None` and is
/// omitted from the serialized JSON when `None`. Which fields are meaningful
/// presumably depends on `content_type`; that pairing is not enforced here.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct MessageContent {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub content_type: MessageContentType,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<MessageContentSource>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signature: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Kept as raw, unparsed JSON text.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<Box<RawValue>>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_id: Option<String>,
    /// Nested blocks (string or array on input).
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<MessageContents>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encrypted_content: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub page_age: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citations: Option<Vec<Citation>>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
/// `source` object of a content block: a required `type` string plus an optional
/// media type and data string (each omitted when `None`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageContentSource {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<String>,
}
/// `cache_control` marker: a cache kind plus an optional TTL (omitted when `None`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CacheControl {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<CacheControlTTL>,
}
/// Cache kind; the only wire value is `"ephemeral"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheControlType {
    Ephemeral,
}
/// Cache lifetime, serialized as the literal strings `"5m"` and `"1h"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum CacheControlTTL {
    /// Wire value `"5m"`.
    #[serde(rename = "5m")]
    FiveMinutes,
    /// Wire value `"1h"`.
    #[serde(rename = "1h")]
    OneHour,
}
/// A citation attached to text; all fields are required on input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Citation {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub citation_type: CitationType,
    pub url: String,
    pub title: String,
    pub encrypted_index: String,
    pub cited_text: String,
}
/// Citation kind; the only wire value is `"web_search_result_location"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum CitationType {
    WebSearchResultLocation,
}
// Delta types for streaming

/// Kind of an incremental content update, serialized in snake_case
/// (e.g. `TextDelta` → `"text_delta"`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageContentDeltaType {
    TextDelta,
    InputJsonDelta,
    ThinkingDelta,
    SignatureDelta,
    CitationsDelta,
}
/// Payload of a `content_block_delta` event: a delta kind plus whichever optional
/// field carries the increment (each omitted when `None`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageContentDelta {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub delta_type: MessageContentDeltaType,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub partial_json: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signature: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation: Option<Citation>,
}
/// Request metadata; `user_id` is optional and omitted when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Metadata {
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
}
/// Thinking configuration: a `type` tag plus an optional token budget
/// (omitted when `None`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Thinking {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub thinking_type: ThinkingType,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub budget_tokens: Option<i32>,
}
/// `type` value of [`Thinking`]: `"enabled"`, `"disabled"`, or `"adaptive"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingType {
    Enabled,
    Disabled,
    Adaptive,
}
/// Tool-selection directive.
///
/// `name` is omitted when `None`. `disable_parallel_tool_use` defaults to `false`
/// and is only emitted when `true`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolChoice {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub choice_type: ToolChoiceType,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "is_false")]
    pub disable_parallel_tool_use: bool,
}
/// `skip_serializing_if` predicate: true exactly when the flag is `false`.
fn is_false(v: &bool) -> bool {
    matches!(*v, false)
}
/// `type` value of [`ToolChoice`]: `"tool"`, `"auto"`, `"none"`, or `"any"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceType {
    Tool,
    Auto,
    None,
    Any,
}
/// A tool definition offered with a request.
///
/// `type` is optional. Its known values are `custom` and `web_search_20250305`
/// (see [`ToolType`]). Every optional or defaulted field is omitted from the
/// serialized JSON when empty, zero, or `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Wire name `type`.
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub tool_type: Option<ToolType>,
    pub name: String,
    /// Empty when absent. An empty description is now omitted on output, like the
    /// other optional fields, rather than emitted as `"description": ""`. That key
    /// would otherwise be attached to description-less tools such as
    /// `web_search_20250305`. NOTE(review): presumably upstream rejects unexpected
    /// keys on server tools; confirm.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub description: String,
    /// JSON schema kept as raw, unparsed JSON text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_schema: Option<Box<RawValue>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
    /// `0` when absent; omitted when `0`.
    #[serde(default, skip_serializing_if = "is_zero_i32")]
    pub max_uses: i32,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub allowed_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub blocked_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_location: Option<ToolLocation>,
}
/// `skip_serializing_if` predicate: true exactly when the integer is `0`.
fn is_zero_i32(v: &i32) -> bool {
    matches!(*v, 0)
}
/// Tool kind, serialized as an explicit wire string per variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolType {
    /// Wire value `"custom"`.
    #[serde(rename = "custom")]
    Custom,
    /// Wire value `"web_search_20250305"`.
    #[serde(rename = "web_search_20250305")]
    WebSearch2025,
}
/// `user_location` of a tool: a location kind plus optional address parts
/// (each omitted when `None`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolLocation {
    /// Wire name `type`.
    #[serde(rename = "type")]
    pub location_type: ToolLocationType,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub city: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timezone: Option<String>,
}
/// Location kind; the only wire value is `"approximate"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolLocationType {
    Approximate,
}
/// Token-usage counters.
///
/// Every counter defaults to `0` when absent. This includes `input_tokens` and
/// `output_tokens`, because the usage attached to a `message_delta` stream event
/// may carry only a subset of the counters (presumably just `output_tokens`). With
/// those two fields required, such an event failed to deserialize. Optional
/// breakdowns are omitted when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default)]
    pub input_tokens: i64,
    #[serde(default)]
    pub output_tokens: i64,
    #[serde(default)]
    pub cache_read_input_tokens: i64,
    #[serde(default)]
    pub cache_creation_input_tokens: i64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_creation: Option<CacheCreationUsage>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub server_tool_use: Option<ServerToolUseUsage>,
}
/// Cache-creation token counts split by TTL; both fields are required on input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CacheCreationUsage {
    pub ephemeral_5m_input_tokens: i64,
    pub ephemeral_1h_input_tokens: i64,
}
/// Server-side tool usage counters; `web_search_requests` is required on input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerToolUseUsage {
    pub web_search_requests: i32,
}
// SSE Event types

/// Field-less discriminant mirroring the variants of [`Event`]; serialized in
/// snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EventType {
    Ping,
    Error,
    MessageStart,
    MessageDelta,
    MessageStop,
    ContentBlockStart,
    ContentBlockDelta,
    ContentBlockStop,
}
/// One streamed event.
///
/// Internally tagged: the snake_case variant name is stored in a `type` field
/// next to the variant's own fields (e.g. `{"type":"ping"}`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Event {
    Ping,
    Error {
        error: StreamError,
    },
    MessageStart {
        message: Message,
    },
    MessageDelta {
        delta: MessageDelta,
        /// Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    MessageStop,
    ContentBlockStart {
        index: i32,
        content_block: MessageContent,
    },
    ContentBlockDelta {
        index: i32,
        delta: MessageContentDelta,
    },
    ContentBlockStop {
        index: i32,
    },
}
/// `delta` payload of a `message_delta` event. Both fields default to `None` and
/// are omitted when `None`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct MessageDelta {
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
}
impl Event {
pub fn event_type(&self) -> EventType {
match self {
Event::Ping => EventType::Ping,
Event::Error { .. } => EventType::Error,
Event::MessageStart { .. } => EventType::MessageStart,
Event::MessageDelta { .. } => EventType::MessageDelta,
Event::MessageStop => EventType::MessageStop,
Event::ContentBlockStart { .. } => EventType::ContentBlockStart,
Event::ContentBlockDelta { .. } => EventType::ContentBlockDelta,
Event::ContentBlockStop { .. } => EventType::ContentBlockStop,
}
}
pub fn to_sse_string(&self) -> String {
let event_name = match self.event_type() {
EventType::Ping => "ping",
EventType::Error => "error",
EventType::MessageStart => "message_start",
EventType::MessageDelta => "message_delta",
EventType::MessageStop => "message_stop",
EventType::ContentBlockStart => "content_block_start",
EventType::ContentBlockDelta => "content_block_delta",
EventType::ContentBlockStop => "content_block_stop",
};
let data = serde_json::to_string(self).unwrap_or_default();
format!("event: {}\ndata: {}\n\n", event_name, data)
}
}