Spaces:
Runtime error
Runtime error
fix: update qdrant_client_helper to be async and add collection creation logic
Browse files- src/models/states/mod.rs +20 -4
src/models/states/mod.rs
CHANGED
|
@@ -10,10 +10,13 @@ use deadpool_diesel::{
|
|
| 10 |
postgres::{Manager, Pool},
|
| 11 |
};
|
| 12 |
use hf_hub::api::tokio::ApiRepo;
|
| 13 |
-
use qdrant_client::
|
|
|
|
|
|
|
|
|
|
| 14 |
use tokenizers::Tokenizer;
|
| 15 |
|
| 16 |
-
use crate::constants::EMBEDDING_MODEL_NAME;
|
| 17 |
|
| 18 |
#[derive(Clone)]
|
| 19 |
pub struct AppState {
|
|
@@ -31,7 +34,7 @@ impl AppState {
|
|
| 31 |
|
| 32 |
Ok(Self {
|
| 33 |
diesel_pool: Self::diesel_pool_helper()?,
|
| 34 |
-
qdrant_client: Self::qdrant_client_helper()?,
|
| 35 |
model: Self::model_helper(&api, &device).await?,
|
| 36 |
tokenizer: Self::tokenizer_helper(&api).await?,
|
| 37 |
device: Arc::new(device),
|
|
@@ -46,8 +49,21 @@ impl AppState {
|
|
| 46 |
.build()?)
|
| 47 |
}
|
| 48 |
|
| 49 |
-
fn qdrant_client_helper() -> anyhow::Result<Arc<Qdrant>> {
|
| 50 |
let client = Qdrant::from_url(&std::env::var("QDRANT_URL")?).build()?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
Ok(Arc::new(client))
|
| 52 |
}
|
| 53 |
|
|
|
|
| 10 |
postgres::{Manager, Pool},
|
| 11 |
};
|
| 12 |
use hf_hub::api::tokio::ApiRepo;
|
| 13 |
+
use qdrant_client::{
|
| 14 |
+
Qdrant,
|
| 15 |
+
qdrant::{CreateCollectionBuilder, Distance, VectorParamsBuilder},
|
| 16 |
+
};
|
| 17 |
use tokenizers::Tokenizer;
|
| 18 |
|
| 19 |
+
use crate::constants::{EMBEDDING_MODEL_NAME, QDRANT_KEYFRAME_COLLECTION_NAME};
|
| 20 |
|
| 21 |
#[derive(Clone)]
|
| 22 |
pub struct AppState {
|
|
|
|
| 34 |
|
| 35 |
Ok(Self {
|
| 36 |
diesel_pool: Self::diesel_pool_helper()?,
|
| 37 |
+
qdrant_client: Self::qdrant_client_helper().await?,
|
| 38 |
model: Self::model_helper(&api, &device).await?,
|
| 39 |
tokenizer: Self::tokenizer_helper(&api).await?,
|
| 40 |
device: Arc::new(device),
|
|
|
|
| 49 |
.build()?)
|
| 50 |
}
|
| 51 |
|
| 52 |
+
async fn qdrant_client_helper() -> anyhow::Result<Arc<Qdrant>> {
|
| 53 |
let client = Qdrant::from_url(&std::env::var("QDRANT_URL")?).build()?;
|
| 54 |
+
|
| 55 |
+
if client
|
| 56 |
+
.collection_exists(QDRANT_KEYFRAME_COLLECTION_NAME)
|
| 57 |
+
.await?
|
| 58 |
+
{
|
| 59 |
+
client
|
| 60 |
+
.create_collection(
|
| 61 |
+
CreateCollectionBuilder::new(QDRANT_KEYFRAME_COLLECTION_NAME)
|
| 62 |
+
.vectors_config(VectorParamsBuilder::new(512, Distance::Cosine)),
|
| 63 |
+
)
|
| 64 |
+
.await?;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
Ok(Arc::new(client))
|
| 68 |
}
|
| 69 |
|