File size: 1,777 Bytes
6b23458
 
79bbee8
6b23458
 
68fc6df
6b23458
 
68fc6df
6b23458
 
 
 
 
 
 
68fc6df
6b23458
 
 
 
 
 
68fc6df
6b23458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68fc6df
 
6b23458
 
 
 
 
 
79bbee8
6b23458
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
use qdrant_client::{
    Qdrant,
    qdrant::{Query, QueryPointsBuilder},
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use translators::{GoogleTranslator, Translator};

use crate::{
    constants::QDRANT_KEYFRAME_COLLECTION_NAME,
    models::{dtos::vectors::keyframes::VectorizedKeyframeDto, states::AppState},
    services::embeddings::texts::TextEmbeddingService,
};

/// Service for searching vectorized keyframes stored in Qdrant.
///
/// The service only holds shared references, so it is cheap to copy; the
/// lifetime `'a` ties it to the owner of the underlying client and translator.
#[derive(Clone, Copy)]
pub struct VectorizedKeyframeService<'a> {
    /// Qdrant client used to run vector queries.
    client: &'a Qdrant,
    /// Translator applied to the query text before it is embedded.
    translator: &'a GoogleTranslator,
}

impl<'a> From<&'a AppState> for VectorizedKeyframeService<'a> {
    fn from(value: &'a AppState) -> Self {
        Self {
            client: value.qdrant_client(),
            translator: value.translator(),
        }
    }
}

impl<'a> VectorizedKeyframeService<'a> {
    /// Turns `text` into an embedding vector.
    ///
    /// # Errors
    ///
    /// Fails if the embedding service cannot be created or if the embedding
    /// call itself returns an error.
    async fn embed_text(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        // NOTE(review): a new `TextEmbeddingService` is constructed on every
        // call; if construction is expensive it may be worth sharing one
        // instance — confirm against the service implementation.
        let mut embedder = TextEmbeddingService::new().await?;
        let vector = embedder.embed_text(text).await;
        vector
    }

    /// Returns up to `top_k` keyframes nearest to the given text query.
    ///
    /// The text is first translated with source language `"vi"` and target
    /// language `"en"`, then embedded, and the resulting vector is used for a
    /// nearest-neighbour query against the keyframe collection using the
    /// `"images"` vector.
    ///
    /// # Errors
    ///
    /// Fails if translation, embedding, the Qdrant query, or the conversion of
    /// any returned point into a [`VectorizedKeyframeDto`] fails.
    pub async fn find_nearest_top_k_by_text(
        &self,
        text: &str,
        top_k: u64,
    ) -> anyhow::Result<Vec<VectorizedKeyframeDto>> {
        let english_query = self.translator.translate_async(text, "vi", "en").await?;
        let query_vector = self.embed_text(&english_query).await?;

        // Assemble the request up front so the client call stays short.
        let request = QueryPointsBuilder::new(QDRANT_KEYFRAME_COLLECTION_NAME)
            .query(Query::new_nearest(query_vector))
            .using("images")
            .limit(top_k);
        let response = self.client.query(request).await?;

        // Convert every scored point; any failed conversion fails the whole call.
        response
            .result
            .into_par_iter()
            .map(VectorizedKeyframeDto::try_from)
            .collect()
    }
}