fyerfyer committed on
Commit
f9147ba
·
1 Parent(s): b10e29c

removed rerank part, added hybrid index

Browse files
app.py CHANGED
@@ -2,10 +2,9 @@ import os
2
  import httpx
3
  import gradio as gr
4
  from openai import OpenAI
5
- from qdrant_client import QdrantClient
6
  from sentence_transformers import SentenceTransformer
7
- from flashrank import Ranker, RerankRequest
8
- from types import SimpleNamespace
9
 
10
  API_KEY = os.environ.get('DEEPSEEK_API_KEY')
11
  BASE_URL = "https://api.deepseek.com"
@@ -13,10 +12,12 @@ BASE_URL = "https://api.deepseek.com"
13
  QDRANT_PATH = "./qdrant_db"
14
  COLLECTION_NAME = "huggingface_transformers_docs"
15
  EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"
 
16
 
17
  class HFRAG:
18
  def __init__(self):
19
- self.embed_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)
 
20
 
21
  lock_file = os.path.join(QDRANT_PATH, ".lock")
22
  if os.path.exists(lock_file):
@@ -42,51 +43,49 @@ class HFRAG:
42
  http_client=httpx.Client(proxy=None, trust_env=False)
43
  )
44
 
45
- self.reranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/tmp")
46
-
47
- def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.40):
48
- query_vector = self.embed_model.encode(query).tolist()
 
 
 
 
 
 
49
 
50
- if hasattr(self.db_client, 'search'):
51
- results = self.db_client.search(
52
- collection_name=COLLECTION_NAME,
53
- query_vector=query_vector,
54
- limit=20, # 扩大召回范围,之后进行重排序
55
- score_threshold=score_threshold
56
- )
57
- else:
58
- results = self.db_client.query_points(
59
- collection_name=COLLECTION_NAME,
60
- query=query_vector,
61
- limit=20,
62
- with_payload=True,
63
- score_threshold=score_threshold
64
- ).points
65
 
66
- passages = [
67
- {"id": result.payload['metadata']['source'], "text": result.payload['text'], "meta": result.payload}
68
- for result in results
69
- ]
70
- rerank_request = RerankRequest(query=query, passages=passages)
71
- reranked_results = self.reranker.rerank(rerank_request)
72
-
73
- # 从重排序后的序列中取出 TopK
74
- final_results = []
75
- for item in reranked_results[:top_k]:
76
- final_result = SimpleNamespace()
77
- final_result.payload = item['meta']
78
- final_result.score = item['score']
79
- final_results.append(final_result)
80
-
81
- return final_results
 
82
 
83
  def format_context(self, search_results):
84
  context_pieces = []
85
  sources_summary = []
86
 
87
  for idx, hit in enumerate(search_results, 1):
88
- raw_source = hit.payload['metadata']['source']
89
- filename = raw_source.split('/')[-1]
90
  text = hit.payload['text']
91
  score = hit.score
92
 
 
2
  import httpx
3
  import gradio as gr
4
  from openai import OpenAI
5
+ from qdrant_client import QdrantClient, models
6
  from sentence_transformers import SentenceTransformer
7
+ from fastembed import SparseTextEmbedding
 
8
 
9
  API_KEY = os.environ.get('DEEPSEEK_API_KEY')
10
  BASE_URL = "https://api.deepseek.com"
 
12
  QDRANT_PATH = "./qdrant_db"
13
  COLLECTION_NAME = "huggingface_transformers_docs"
14
  EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"
15
+ SPARSE_MODEL_ID = "prithivida/Splade_PP_en_v1"
16
 
17
  class HFRAG:
18
  def __init__(self):
19
+ self.dense_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)
20
+ self.sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_ID)
21
 
22
  lock_file = os.path.join(QDRANT_PATH, ".lock")
23
  if os.path.exists(lock_file):
 
43
  http_client=httpx.Client(proxy=None, trust_env=False)
44
  )
45
 
46
+ def retrieve(self, query: str, top_k: int = 5):
47
+ # Generate dense vector
48
+ query_dense_vec = self.dense_model.encode(query).tolist()
49
+
50
+ # Generate sparse vector
51
+ query_sparse_gen = list(self.sparse_model.embed([query]))[0]
52
+ query_sparse_vec = models.SparseVector(
53
+ indices=query_sparse_gen.indices.tolist(),
54
+ values=query_sparse_gen.values.tolist()
55
+ )
56
 
57
+ # Create prefetch for dense retrieval
58
+ prefetch_dense = models.Prefetch(
59
+ query=query_dense_vec,
60
+ using="text-dense",
61
+ limit=20,
62
+ )
 
 
 
 
 
 
 
 
 
63
 
64
+ # Create prefetch for sparse retrieval
65
+ prefetch_sparse = models.Prefetch(
66
+ query=query_sparse_vec,
67
+ using="text-sparse",
68
+ limit=20,
69
+ )
70
+
71
+ # Hybrid search with RRF fusion
72
+ results = self.db_client.query_points(
73
+ collection_name=COLLECTION_NAME,
74
+ prefetch=[prefetch_dense, prefetch_sparse],
75
+ query=models.FusionQuery(fusion=models.Fusion.RRF),
76
+ limit=top_k,
77
+ with_payload=True
78
+ ).points
79
+
80
+ return results
81
 
82
  def format_context(self, search_results):
83
  context_pieces = []
84
  sources_summary = []
85
 
86
  for idx, hit in enumerate(search_results, 1):
87
+ raw_source = hit.payload.get('source', 'unknown')
88
+ filename = raw_source.split('/')[-1] if '/' in raw_source else raw_source
89
  text = hit.payload['text']
90
  score = hit.score
91
 
qdrant_db/collection/huggingface_transformers_docs/storage.sqlite CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:88a55f2d047299d73d59f44f05d0ef0bf03ca865ae5dbd5523eed72269cb0f98
3
- size 56549376
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:767ba990e571262333521d2528e9d57f248e9cd63f6269907716bea20617c607
3
+ size 62464000
qdrant_db/meta.json CHANGED
@@ -1 +1 @@
1
- {"collections": {"huggingface_transformers_docs": {"vectors": {"size": 768, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null, "datatype": null, "multivector_config": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "quantization_config": null, "sparse_vectors": null, "strict_mode_config": null, "metadata": null}}, "aliases": {}}
 
1
+ {"collections": {"huggingface_transformers_docs": {"vectors": {"text-dense": {"size": 768, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null, "datatype": null, "multivector_config": null}}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "quantization_config": null, "sparse_vectors": {"text-sparse": {"index": {"full_scan_threshold": null, "on_disk": true, "datatype": null}, "modifier": null}}, "strict_mode_config": null, "metadata": null}}, "aliases": {}}
requirements.txt CHANGED
@@ -5,4 +5,4 @@ sentence-transformers
5
  httpx
6
  torch
7
  python-dotenv
8
- flashrank
 
5
  httpx
6
  torch
7
  python-dotenv
8
+ fastembed