Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| use serde::{Deserialize, Deserializer, Serialize, Serializer}; | |
| use serde_json::value::RawValue; | |
| // Constants | |
/// Header name `x-api-key` (presumably the HTTP header carrying the caller's API key — confirm at call sites).
pub const HEADER_API_KEY: &str = "x-api-key";
| // Error types | |
| pub struct Error { | |
| pub content_type: String, | |
| pub error: InnerError, | |
| } | |
/// Error details carried inside [`Error`]: an error kind and a message.
pub struct InnerError {
    /// Error kind identifier (presumably serialized as `type` — confirm).
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
/// Error payload carried by the `Event::Error` streaming event.
pub struct StreamError {
    /// Error kind identifier (presumably serialized as `type` — confirm).
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
| // Request types | |
| pub struct GenerateMessageRequest { | |
| pub system: MessageContents, | |
| pub model: String, | |
| pub messages: Vec<Message>, | |
| pub max_tokens: i32, | |
| pub metadata: Option<Metadata>, | |
| pub stop_sequences: Vec<String>, | |
| pub thinking: Option<Thinking>, | |
| pub tool_choice: Option<ToolChoice>, | |
| pub tools: Vec<Tool>, | |
| pub temperature: f64, | |
| pub top_k: Option<i32>, | |
| pub top_p: Option<f64>, | |
| pub stream: bool, | |
| } | |
/// Returns `true` when `*v` compares equal to `0.0` (this includes `-0.0`; NaN yields `false`).
///
/// Takes a reference, presumably so it can be used as a serde
/// `skip_serializing_if` predicate — the attribute is not visible in this view.
fn is_zero_f64(v: &f64) -> bool {
    let value = *v;
    value == 0.0
}
/// Always returns `true`.
///
/// Presumably referenced by a serde `default = "..."` attribute that is not
/// visible in this view.
fn default_true() -> bool {
    true
}
| pub struct CountTokensRequest { | |
| pub system: MessageContents, | |
| pub model: String, | |
| pub messages: Vec<Message>, | |
| pub thinking: Option<Thinking>, | |
| pub tool_choice: Option<ToolChoice>, | |
| pub tools: Vec<Tool>, | |
| } | |
/// Response to a [`CountTokensRequest`].
pub struct CountTokensResponse {
    /// Number of input tokens counted.
    pub input_tokens: i64,
}
| // Message types | |
| pub struct Message { | |
| pub id: Option<String>, | |
| pub message_type: Option<MessageType>, | |
| pub role: MessageRole, | |
| pub content: MessageContents, | |
| pub model: Option<String>, | |
| pub stop_reason: Option<StopReason>, | |
| pub stop_sequence: Option<String>, | |
| pub usage: Option<Usage>, | |
| } | |
/// Object-type tag for [`Message`]; `Message` is the only variant.
pub enum MessageType {
    /// The object is a message.
    Message,
}
/// Author role of a [`Message`].
pub enum MessageRole {
    /// Message written by the user.
    User,
    /// Message written by the assistant.
    Assistant,
}
/// Reason why generation stopped.
///
/// NOTE(review): the exact wire spelling of each variant depends on serde
/// attributes that are not visible in this view.
pub enum StopReason {
    /// The turn ended.
    EndTurn,
    /// The token limit was reached.
    MaxTokens,
    /// A stop sequence was hit.
    StopSequence,
    /// Stopped for tool use.
    ToolUse,
    /// The turn was paused.
    PauseTurn,
    /// The request was refused.
    Refusal,
}
| // MessageContents - deserializes from null, a string, or an array | |
| pub struct MessageContents(pub Vec<MessageContent>); | |
| impl MessageContents { | |
| pub fn is_empty(&self) -> bool { | |
| self.0.is_empty() | |
| } | |
| pub fn len(&self) -> usize { | |
| self.0.len() | |
| } | |
| pub fn iter(&self) -> impl Iterator<Item = &MessageContent> { | |
| self.0.iter() | |
| } | |
| } | |
| impl Serialize for MessageContents { | |
| fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |
| where | |
| S: Serializer, | |
| { | |
| self.0.serialize(serializer) | |
| } | |
| } | |
| impl<'de> Deserialize<'de> for MessageContents { | |
| fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | |
| where | |
| D: Deserializer<'de>, | |
| { | |
| use serde::de::Error; | |
| let value = serde_json::Value::deserialize(deserializer)?; | |
| match &value { | |
| serde_json::Value::Null => Ok(MessageContents(vec![])), | |
| serde_json::Value::String(s) => Ok(MessageContents(vec![MessageContent { | |
| content_type: MessageContentType::Text, | |
| text: Some(s.clone()), | |
| ..Default::default() | |
| }])), | |
| serde_json::Value::Array(_) => { | |
| let arr: Vec<MessageContent> = serde_json::from_value(value.clone()) | |
| .map_err(|e| { | |
| eprintln!("Failed to deserialize MessageContents array: {}", e); | |
| eprintln!("Value: {}", serde_json::to_string_pretty(&value).unwrap_or_default()); | |
| D::Error::custom(e.to_string()) | |
| })?; | |
| Ok(MessageContents(arr)) | |
| } | |
| _ => { | |
| eprintln!("Unexpected MessageContents type: {}", serde_json::to_string_pretty(&value).unwrap_or_default()); | |
| Err(D::Error::custom("content must be null, string, or array")) | |
| } | |
| } | |
| } | |
| } | |
| impl IntoIterator for MessageContents { | |
| type Item = MessageContent; | |
| type IntoIter = std::vec::IntoIter<MessageContent>; | |
| fn into_iter(self) -> Self::IntoIter { | |
| self.0.into_iter() | |
| } | |
| } | |
| impl<'a> IntoIterator for &'a MessageContents { | |
| type Item = &'a MessageContent; | |
| type IntoIter = std::slice::Iter<'a, MessageContent>; | |
| fn into_iter(self) -> Self::IntoIter { | |
| self.0.iter() | |
| } | |
| } | |
/// Discriminator for the kind of a [`MessageContent`] block.
///
/// NOTE(review): `MessageContent` is built with `..Default::default()` in this
/// file, so this enum needs a `Default` implementation; it is not visible in
/// this view, and neither is the default variant.
pub enum MessageContentType {
    /// Plain text block.
    Text,
    /// Image block.
    Image,
    /// Tool-use block.
    ToolUse,
    /// Tool-result block.
    ToolResult,
    /// Thinking block.
    Thinking,
    /// Redacted thinking block.
    RedactedThinking,
    /// Server tool-use block.
    ServerToolUse,
    /// Web-search tool-result block.
    WebSearchToolResult,
    /// Single web-search result.
    WebSearchResult,
}
| pub struct MessageContent { | |
| pub content_type: MessageContentType, | |
| pub text: Option<String>, | |
| pub source: Option<MessageContentSource>, | |
| pub thinking: Option<String>, | |
| pub signature: Option<String>, | |
| pub data: Option<String>, | |
| pub id: Option<String>, | |
| pub name: Option<String>, | |
| pub input: Option<Box<RawValue>>, | |
| pub tool_use_id: Option<String>, | |
| pub content: Option<MessageContents>, | |
| pub title: Option<String>, | |
| pub url: Option<String>, | |
| pub encrypted_content: Option<String>, | |
| pub page_age: Option<String>, | |
| pub citations: Option<Vec<Citation>>, | |
| pub cache_control: Option<CacheControl>, | |
| } | |
/// Source descriptor attached to a [`MessageContent`] via its `source` field.
pub struct MessageContentSource {
    /// Kind of source (presumably serialized as `type` — confirm).
    pub source_type: String,
    /// Optional media type of the data.
    pub media_type: Option<String>,
    /// Optional data payload (encoding not established by this file).
    pub data: Option<String>,
}
| pub struct CacheControl { | |
| pub cache_type: CacheControlType, | |
| pub ttl: Option<CacheControlTTL>, | |
| } | |
/// Kind of cache control; `Ephemeral` is the only variant.
pub enum CacheControlType {
    /// Ephemeral caching.
    Ephemeral,
}
/// Time-to-live choices for [`CacheControl`].
pub enum CacheControlTTL {
    /// Five-minute TTL.
    FiveMinutes,
    /// One-hour TTL.
    OneHour,
}
| pub struct Citation { | |
| pub citation_type: CitationType, | |
| pub url: String, | |
| pub title: String, | |
| pub encrypted_index: String, | |
| pub cited_text: String, | |
| } | |
/// Kind of [`Citation`]; `WebSearchResultLocation` is the only variant.
pub enum CitationType {
    /// Citation pointing at a web-search result location.
    WebSearchResultLocation,
}
| // Delta types for streaming | |
/// Discriminator for the kind of a [`MessageContentDelta`].
pub enum MessageContentDeltaType {
    /// Text delta.
    TextDelta,
    /// Partial input-JSON delta.
    InputJsonDelta,
    /// Thinking delta.
    ThinkingDelta,
    /// Signature delta.
    SignatureDelta,
    /// Citations delta.
    CitationsDelta,
}
| pub struct MessageContentDelta { | |
| pub delta_type: MessageContentDeltaType, | |
| pub text: Option<String>, | |
| pub partial_json: Option<String>, | |
| pub thinking: Option<String>, | |
| pub signature: Option<String>, | |
| pub citation: Option<Citation>, | |
| } | |
/// Optional metadata attached to a [`GenerateMessageRequest`].
pub struct Metadata {
    /// Optional user identifier.
    pub user_id: Option<String>,
}
| pub struct Thinking { | |
| pub thinking_type: ThinkingType, | |
| pub budget_tokens: Option<i32>, | |
| } | |
/// Thinking modes selectable via [`Thinking`].
pub enum ThinkingType {
    /// Thinking is enabled.
    Enabled,
    /// Thinking is disabled.
    Disabled,
    /// Adaptive thinking.
    Adaptive,
}
| pub struct ToolChoice { | |
| pub choice_type: ToolChoiceType, | |
| pub name: Option<String>, | |
| pub disable_parallel_tool_use: bool, | |
| } | |
/// Returns `true` when `*v` is `false`.
///
/// Takes a reference, presumably so it can be used as a serde
/// `skip_serializing_if` predicate — the attribute is not visible in this view.
fn is_false(v: &bool) -> bool {
    let flag = *v;
    !flag
}
/// Selection modes for [`ToolChoice`].
///
/// Note that the `None` variant here is `ToolChoiceType::None`, not
/// `Option::None`.
pub enum ToolChoiceType {
    /// A specific tool.
    Tool,
    /// Automatic selection.
    Auto,
    /// No tool.
    None,
    /// Any tool.
    Any,
}
| pub struct Tool { | |
| pub tool_type: Option<ToolType>, | |
| pub name: String, | |
| pub description: String, | |
| pub input_schema: Option<Box<RawValue>>, | |
| pub cache_control: Option<CacheControl>, | |
| pub max_uses: i32, | |
| pub allowed_domains: Vec<String>, | |
| pub blocked_domains: Vec<String>, | |
| pub user_location: Option<ToolLocation>, | |
| } | |
/// Returns `true` when `*v` is `0`.
///
/// Takes a reference, presumably so it can be used as a serde
/// `skip_serializing_if` predicate — the attribute is not visible in this view.
fn is_zero_i32(v: &i32) -> bool {
    matches!(*v, 0)
}
/// Kinds of [`Tool`].
pub enum ToolType {
    /// Custom tool.
    Custom,
    /// Web-search tool (the `2025` suffix is presumably a version tag — confirm).
    WebSearch2025,
}
| pub struct ToolLocation { | |
| pub location_type: ToolLocationType, | |
| pub city: Option<String>, | |
| pub region: Option<String>, | |
| pub country: Option<String>, | |
| pub timezone: Option<String>, | |
| } | |
/// Kind of [`ToolLocation`]; `Approximate` is the only variant.
pub enum ToolLocationType {
    /// Approximate location.
    Approximate,
}
| pub struct Usage { | |
| pub input_tokens: i64, | |
| pub output_tokens: i64, | |
| pub cache_read_input_tokens: i64, | |
| pub cache_creation_input_tokens: i64, | |
| pub cache_creation: Option<CacheCreationUsage>, | |
| pub server_tool_use: Option<ServerToolUseUsage>, | |
| } | |
/// Cache-creation token counts, split into `5m` and `1h` ephemeral buckets.
pub struct CacheCreationUsage {
    /// Input tokens in the `ephemeral_5m` bucket.
    pub ephemeral_5m_input_tokens: i64,
    /// Input tokens in the `ephemeral_1h` bucket.
    pub ephemeral_1h_input_tokens: i64,
}
/// Server tool-use counters reported inside [`Usage`].
pub struct ServerToolUseUsage {
    /// Number of web-search requests.
    pub web_search_requests: i32,
}
| // SSE Event types | |
/// Payload-free discriminator for [`Event`]; there is one variant per event kind.
pub enum EventType {
    /// Ping event.
    Ping,
    /// Error event.
    Error,
    /// Start of a message.
    MessageStart,
    /// Message-level delta.
    MessageDelta,
    /// End of a message.
    MessageStop,
    /// Start of a content block.
    ContentBlockStart,
    /// Content-block delta.
    ContentBlockDelta,
    /// End of a content block.
    ContentBlockStop,
}
| pub enum Event { | |
| Ping, | |
| Error { | |
| error: StreamError, | |
| }, | |
| MessageStart { | |
| message: Message, | |
| }, | |
| MessageDelta { | |
| delta: MessageDelta, | |
| usage: Option<Usage>, | |
| }, | |
| MessageStop, | |
| ContentBlockStart { | |
| index: i32, | |
| content_block: MessageContent, | |
| }, | |
| ContentBlockDelta { | |
| index: i32, | |
| delta: MessageContentDelta, | |
| }, | |
| ContentBlockStop { | |
| index: i32, | |
| }, | |
| } | |
| pub struct MessageDelta { | |
| pub stop_reason: Option<StopReason>, | |
| pub stop_sequence: Option<String>, | |
| } | |
| impl Event { | |
| pub fn event_type(&self) -> EventType { | |
| match self { | |
| Event::Ping => EventType::Ping, | |
| Event::Error { .. } => EventType::Error, | |
| Event::MessageStart { .. } => EventType::MessageStart, | |
| Event::MessageDelta { .. } => EventType::MessageDelta, | |
| Event::MessageStop => EventType::MessageStop, | |
| Event::ContentBlockStart { .. } => EventType::ContentBlockStart, | |
| Event::ContentBlockDelta { .. } => EventType::ContentBlockDelta, | |
| Event::ContentBlockStop { .. } => EventType::ContentBlockStop, | |
| } | |
| } | |
| pub fn to_sse_string(&self) -> String { | |
| let event_name = match self.event_type() { | |
| EventType::Ping => "ping", | |
| EventType::Error => "error", | |
| EventType::MessageStart => "message_start", | |
| EventType::MessageDelta => "message_delta", | |
| EventType::MessageStop => "message_stop", | |
| EventType::ContentBlockStart => "content_block_start", | |
| EventType::ContentBlockDelta => "content_block_delta", | |
| EventType::ContentBlockStop => "content_block_stop", | |
| }; | |
| let data = serde_json::to_string(self).unwrap_or_default(); | |
| format!("event: {}\ndata: {}\n\n", event_name, data) | |
| } | |
| } | |