// darkmedia-x-api / engine/src/qdrant_client.rs
// Uploaded by cybermedia ("Upload folder using huggingface_hub"), commit 343eed9 (verified).
use anyhow::{Context, Result};
use qdrant_client::{
    config::QdrantConfig,
    qdrant::{
        CreateCollectionBuilder, Distance, PointStruct, SearchPointsBuilder, UpsertPointsBuilder,
        VectorParamsBuilder,
    },
    Payload, Qdrant,
};
use serde::{Deserialize, Serialize};
/// Metadata describing one story.
///
/// Serialized into the Qdrant point payload by `QdrantClientWrapper::upsert_story` and
/// deserialized back from the payload of search hits by `QdrantClientWrapper::search_similar`,
/// so field names here are the payload keys stored in Qdrant.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StoryMetadata {
    /// Identifier of the story.
    pub story_id: String,
    /// Story title.
    pub title: String,
    /// Free-form tags attached to the story.
    pub tags: Vec<String>,
    /// Category name of the story.
    pub category: String,
    /// Last-update time. NOTE(review): presumably a Unix timestamp — confirm unit (s vs ms).
    pub last_updated: u64,
}
/// A Qdrant client bound to a single collection.
///
/// Every method on this type issues its request against `collection_name`.
pub struct QdrantClientWrapper {
    /// Name of the collection all requests target.
    collection_name: String,
    /// Underlying Qdrant client used to send requests.
    client: Qdrant,
}
impl QdrantClientWrapper {
pub async fn new(url: &str, api_key: Option<String>, collection_name: String) -> Result<Self> {
let mut config = QdrantConfig::from_url(url);
if let Some(key) = api_key {
config.api_key = Some(key);
}
let client = Qdrant::new(config)?;
Ok(Self {
client,
collection_name,
})
}
pub async fn search_similar(
&self,
vector: Vec<f32>,
limit: u64,
threshold: f32,
) -> Result<Vec<(StoryMetadata, f32)>> {
let search_result = self.client.search_points(
SearchPointsBuilder::new(self.collection_name.clone(), vector, limit)
.score_threshold(threshold)
.with_payload(true)
).await?;
let mut results = Vec::new();
for point in search_result.result {
let metadata: StoryMetadata = serde_json::from_value(serde_json::to_value(point.payload)?)?;
results.push((metadata, point.score));
}
Ok(results)
}
pub async fn upsert_story(
&self,
id: String,
vector: Vec<f32>,
metadata: StoryMetadata,
) -> Result<()> {
let payload: Payload = serde_json::from_value(serde_json::to_value(metadata)?)?;
let point = PointStruct::new(
id,
vector,
payload,
);
self.client.upsert_points(
UpsertPointsBuilder::new(self.collection_name.clone(), vec![point])
).await?;
Ok(())
}
pub async fn collection_exists(&self) -> Result<bool> {
match self.client.collection_info(self.collection_name.clone()).await {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
pub async fn create_collection(&self, vector_size: u64) -> Result<()> {
self.client.create_collection(
CreateCollectionBuilder::new(self.collection_name.clone())
.vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine))
).await?;
Ok(())
}
}