File size: 7,878 Bytes
cf57c77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
use async_openai::{
config::{AzureConfig, OpenAIConfig},
error::OpenAIError,
types::{
ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs,
CreateChatCompletionResponse, CreateEmbeddingRequestArgs, Role,
},
Client,
};
use async_trait::async_trait;
use deadpool::managed::{Manager, RecycleResult};
use futures_util::future::try_join_all;
use redis::{FromRedisValue, RedisError, Value};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Field-less pool manager type; its `Manager` implementation produces
/// `AnyOpenAIClient` values.
pub struct OpenAIClientManager {}
#[async_trait]
impl Manager for OpenAIClientManager {
    type Type = AnyOpenAIClient;
    type Error = MotorheadError;

    /// Builds a new client for the pool.
    ///
    /// An Azure-configured client (API version `2023-05-15`) is returned only
    /// when `AZURE_API_KEY`, `AZURE_DEPLOYMENT_ID` and `AZURE_API_BASE` can all
    /// be read from the environment. If any of them cannot be read, the
    /// default `Client::new()` is used instead.
    ///
    /// This implementation never returns an error.
    async fn create(&self) -> Result<AnyOpenAIClient, MotorheadError> {
        let azure_settings = (
            env::var("AZURE_API_KEY"),
            env::var("AZURE_DEPLOYMENT_ID"),
            env::var("AZURE_API_BASE"),
        );

        if let (Ok(api_key), Ok(deployment_id), Ok(api_base)) = azure_settings {
            let azure = AzureConfig::new()
                .with_api_base(api_base)
                .with_api_key(api_key)
                .with_deployment_id(deployment_id)
                .with_api_version("2023-05-15");
            return Ok(AnyOpenAIClient::Azure(Client::with_config(azure)));
        }

        Ok(AnyOpenAIClient::OpenAI(Client::new()))
    }

    /// Accepts every returned client for reuse without inspecting it.
    async fn recycle(&self, _: &mut AnyOpenAIClient) -> RecycleResult<MotorheadError> {
        Ok(())
    }
}
/// A client bound to one of the two supported API configurations.
pub enum AnyOpenAIClient {
/// Client built from an `AzureConfig`.
Azure(Client<AzureConfig>),
/// Client built from an `OpenAIConfig`.
OpenAI(Client<OpenAIConfig>),
}
impl AnyOpenAIClient {
    /// Requests a chat completion for a single user message.
    ///
    /// The request contains exactly one message with role `User` whose content
    /// is `progresive_prompt`, targets `model`, and caps the answer at 512
    /// tokens. The same request is sent whichever backend this client wraps.
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if the request cannot be built or the API
    /// call fails.
    pub async fn create_chat_completion(
        &self,
        model: &str,
        progresive_prompt: &str,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        let request = CreateChatCompletionRequestArgs::default()
            .max_tokens(512u16)
            .model(model)
            .messages([ChatCompletionRequestMessageArgs::default()
                .role(Role::User)
                .content(progresive_prompt)
                .build()?])
            .build()?;

        match self {
            AnyOpenAIClient::Azure(client) => client.chat().create(request).await,
            AnyOpenAIClient::OpenAI(client) => client.chat().create(request).await,
        }
    }

    /// Computes embeddings for every string in `query_vec` using the
    /// `text-embedding-ada-002` model.
    ///
    /// * OpenAI backend: all inputs are sent in one batched request.
    /// * Azure backend: one request is issued per input and the requests are
    ///   awaited together.
    ///   NOTE(review): presumably because the Azure deployment does not accept
    ///   batched input — confirm.
    ///
    /// The returned vectors are the `embedding` fields of the response data,
    /// in the order the responses provide them.
    ///
    /// # Errors
    ///
    /// Returns an [`OpenAIError`] if a request cannot be built or an API call
    /// fails. On the Azure backend, one failing request fails the whole call.
    pub async fn create_embedding(
        &self,
        query_vec: Vec<String>,
    ) -> Result<Vec<Vec<f32>>, OpenAIError> {
        const EMBEDDING_MODEL: &str = "text-embedding-ada-002";

        match self {
            AnyOpenAIClient::OpenAI(client) => {
                let request = CreateEmbeddingRequestArgs::default()
                    .model(EMBEDDING_MODEL)
                    .input(query_vec)
                    .build()?;
                let response = client.embeddings().create(request).await?;

                // The response is owned, so the vectors are moved out rather
                // than cloned.
                Ok(response
                    .data
                    .into_iter()
                    .map(|item| item.embedding)
                    .collect())
            }
            AnyOpenAIClient::Azure(client) => {
                let requests = query_vec.into_iter().map(|query| async {
                    let request = CreateEmbeddingRequestArgs::default()
                        .model(EMBEDDING_MODEL)
                        .input(vec![query])
                        .build()?;
                    client.embeddings().create(request).await
                });
                let responses = try_join_all(requests).await?;

                Ok(responses
                    .into_iter()
                    .flat_map(|response| response.data)
                    .map(|item| item.embedding)
                    .collect())
            }
        }
    }
}
/// Bundle of application-level settings and shared resources.
/// NOTE(review): presumably handed to every request handler — confirm at the call sites.
pub struct AppState {
/// NOTE(review): presumably the number of messages kept per session — confirm.
pub window_size: i64,
/// Map of string keys to boolean flags, shared behind an `Arc` and an async mutex.
/// NOTE(review): keys look like session ids with a cleanup in flight — confirm.
pub session_cleanup: Arc<Mutex<HashMap<String, bool>>>,
/// Pool whose objects are created by `OpenAIClientManager`.
pub openai_pool: deadpool::managed::Pool<OpenAIClientManager>,
/// NOTE(review): presumably enables the long-term memory feature — confirm.
pub long_term_memory: bool,
/// NOTE(review): presumably the chat model name used for completions — confirm.
pub model: String,
}
/// Serializable payload carrying a single `text` string.
/// NOTE(review): presumably the body of a search request — confirm.
#[derive(Serialize, Deserialize)]
pub struct SearchPayload {
/// The text carried by the payload.
pub text: String,
}
/// One stored message: a role label plus its text content.
#[derive(Serialize, Deserialize, Clone)]
pub struct MemoryMessage {
/// Role label of the message author, kept as a free-form string.
pub role: String,
/// Text of the message.
pub content: String,
}
/// Deserialize-only input: a list of messages and an optional context string.
#[derive(Deserialize)]
pub struct MemoryMessagesAndContext {
/// Messages supplied by the caller.
pub messages: Vec<MemoryMessage>,
/// Optional context text; `None` when absent.
pub context: Option<String>,
}
/// Serialize-only output: messages plus optional context and token count.
#[derive(Serialize)]
pub struct MemoryResponse {
/// Messages included in the response.
pub messages: Vec<MemoryMessage>,
/// Optional context text.
pub context: Option<String>,
/// Optional token count.
/// NOTE(review): presumably the token total for the stored messages — confirm.
pub tokens: Option<i64>,
}
/// Serialize-only health-check payload.
#[derive(Serialize)]
pub struct HealthCheckResponse {
/// NOTE(review): presumably the current Unix time in milliseconds — confirm against the handler.
pub now: u128,
}
/// Serialize-only acknowledgement payload.
#[derive(Serialize)]
pub struct AckResponse {
/// Static status text chosen by the code that builds the response.
pub status: &'static str,
}
/// Error type of this crate.
#[derive(Debug)]
pub enum MotorheadError {
/// Wraps an error reported by the `redis` crate.
RedisError(RedisError),
/// Carries a message describing an incremental summarization failure.
IncrementalSummarizationError(String),
}
impl std::fmt::Display for MotorheadError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
MotorheadError::RedisError(e) => write!(f, "Redis error: {}", e),
MotorheadError::IncrementalSummarizationError(e) => {
write!(f, "Incremental summarization error: {}", e)
}
}
}
}
impl From<Box<dyn Error + Send + Sync>> for MotorheadError {
fn from(error: Box<dyn Error + Send + Sync>) -> Self {
MotorheadError::IncrementalSummarizationError(error.to_string())
}
}
impl From<RedisError> for MotorheadError {
fn from(err: RedisError) -> Self {
MotorheadError::RedisError(err)
}
}
// Marker impl: relies on the trait's default methods, so no `source()` is exposed.
impl std::error::Error for MotorheadError {}
/// One search hit, as decoded by this type's `FromRedisValue` implementation.
#[derive(Serialize, Deserialize, Debug)]
pub struct RedisearchResult {
/// Value of the hit's `role` field.
pub role: String,
/// Value of the hit's `content` field.
pub content: String,
/// Value of the hit's `dist` field.
/// NOTE(review): presumably the vector distance to the query — confirm.
pub dist: f64,
}
impl FromRedisValue for RedisearchResult {
    /// Decodes a flat `[name, value, name, value, ...]` reply into a
    /// [`RedisearchResult`].
    ///
    /// The reply is first converted to a `Vec<String>` and then read as strict
    /// (name, value) pairs:
    ///
    /// * `content` and `role` are copied verbatim;
    /// * `dist` is parsed as `f64`, falling back to `0.0` when it is not a
    ///   valid number;
    /// * any other field name is ignored.
    ///
    /// Fields missing from the reply keep their defaults (empty string or
    /// `0.0`). A trailing name without a value is ignored rather than
    /// indexing past the end of the reply.
    ///
    /// # Errors
    ///
    /// Returns the underlying error if the reply cannot be converted to a
    /// list of strings.
    fn from_redis_value(v: &Value) -> redis::RedisResult<Self> {
        let values: Vec<String> = redis::from_redis_value(v)?;

        let mut content = String::new();
        let mut role = String::new();
        let mut dist = 0.0;

        // Consume two elements per step so a *value* that happens to equal a
        // field name (e.g. a message whose content is "role") is never
        // mistaken for a key.
        let mut fields = values.into_iter();
        while let (Some(name), Some(value)) = (fields.next(), fields.next()) {
            match name.as_str() {
                "content" => content = value,
                "role" => role = value,
                "dist" => dist = value.parse::<f64>().unwrap_or(0.0),
                _ => {}
            }
        }

        Ok(RedisearchResult {
            role,
            content,
            dist,
        })
    }
}
/// Extracts every decodable hit from a search reply.
///
/// When `response` is a `Value::Bulk`, its first element is skipped and each
/// remaining element that is itself a `Value::Bulk` is decoded with
/// `RedisearchResult::from_redis_value`. Non-bulk elements and elements that
/// fail to decode are silently dropped. Any other kind of `response` yields an
/// empty vector.
///
/// NOTE(review): the skipped first element is presumably the total hit count
/// and the dropped non-bulk elements the document keys — confirm against the
/// query that produces the reply.
pub fn parse_redisearch_response(response: &Value) -> Vec<RedisearchResult> {
    match response {
        Value::Bulk(entries) => entries
            .iter()
            .skip(1)
            .filter_map(|entry| match entry {
                // `entry` already is the bulk value to decode, so it is
                // borrowed as-is rather than cloned and re-wrapped.
                Value::Bulk(_) => RedisearchResult::from_redis_value(entry).ok(),
                _ => None,
            })
            .collect(),
        _ => Vec::new(),
    }
}
/// Deserialize-only parameters carrying an optional namespace.
/// NOTE(review): presumably parsed from a URL query string — confirm at the handlers.
#[derive(serde::Deserialize)]
pub struct NamespaceQuery {
/// Optional namespace; `None` when absent.
pub namespace: Option<String>,
}
/// Deserialize-only paging parameters with an optional namespace.
/// NOTE(review): presumably parsed from the query string of a session-listing request — confirm.
#[derive(serde::Deserialize)]
pub struct GetSessionsQuery {
/// Page number; filled by `default_page` (1) when the field is absent.
#[serde(default = "default_page")]
pub page: usize,
/// Page size; filled by `default_size` (10) when the field is absent.
#[serde(default = "default_size")]
pub size: usize,
/// Optional namespace; `None` when absent.
pub namespace: Option<String>,
}
/// Serde default for `GetSessionsQuery::page`: the first page.
fn default_page() -> usize {
    const FIRST_PAGE: usize = 1;
    FIRST_PAGE
}
/// Serde default for `GetSessionsQuery::size`: ten entries per page.
fn default_size() -> usize {
    const PAGE_SIZE: usize = 10;
    PAGE_SIZE
}
|