Spaces:
Sleeping
Sleeping
| use qdrant_client::{ | |
| Qdrant, | |
| config::QdrantConfig, | |
| qdrant::{ | |
| PointStruct, SearchPointsBuilder, UpsertPointsBuilder, CreateCollectionBuilder, Distance, VectorParamsBuilder, | |
| }, | |
| Payload, | |
| }; | |
| use serde::{Deserialize, Serialize}; | |
| use anyhow::Result; | |
| pub struct StoryMetadata { | |
| pub story_id: String, | |
| pub title: String, | |
| pub tags: Vec<String>, | |
| pub category: String, | |
| pub last_updated: u64, | |
| } | |
| pub struct QdrantClientWrapper { | |
| client: Qdrant, | |
| collection_name: String, | |
| } | |
| impl QdrantClientWrapper { | |
| pub async fn new(url: &str, api_key: Option<String>, collection_name: String) -> Result<Self> { | |
| let mut config = QdrantConfig::from_url(url); | |
| if let Some(key) = api_key { | |
| config.api_key = Some(key); | |
| } | |
| let client = Qdrant::new(config)?; | |
| Ok(Self { | |
| client, | |
| collection_name, | |
| }) | |
| } | |
| pub async fn search_similar( | |
| &self, | |
| vector: Vec<f32>, | |
| limit: u64, | |
| threshold: f32, | |
| ) -> Result<Vec<(StoryMetadata, f32)>> { | |
| let search_result = self.client.search_points( | |
| SearchPointsBuilder::new(self.collection_name.clone(), vector, limit) | |
| .score_threshold(threshold) | |
| .with_payload(true) | |
| ).await?; | |
| let mut results = Vec::new(); | |
| for point in search_result.result { | |
| let metadata: StoryMetadata = serde_json::from_value(serde_json::to_value(point.payload)?)?; | |
| results.push((metadata, point.score)); | |
| } | |
| Ok(results) | |
| } | |
| pub async fn upsert_story( | |
| &self, | |
| id: String, | |
| vector: Vec<f32>, | |
| metadata: StoryMetadata, | |
| ) -> Result<()> { | |
| let payload: Payload = serde_json::from_value(serde_json::to_value(metadata)?)?; | |
| let point = PointStruct::new( | |
| id, | |
| vector, | |
| payload, | |
| ); | |
| self.client.upsert_points( | |
| UpsertPointsBuilder::new(self.collection_name.clone(), vec![point]) | |
| ).await?; | |
| Ok(()) | |
| } | |
| pub async fn collection_exists(&self) -> Result<bool> { | |
| match self.client.collection_info(self.collection_name.clone()).await { | |
| Ok(_) => Ok(true), | |
| Err(_) => Ok(false), | |
| } | |
| } | |
| pub async fn create_collection(&self, vector_size: u64) -> Result<()> { | |
| self.client.create_collection( | |
| CreateCollectionBuilder::new(self.collection_name.clone()) | |
| .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)) | |
| ).await?; | |
| Ok(()) | |
| } | |
| } | |