FD900 committed on
Commit
06f1955
·
verified ·
1 Parent(s): 5009282

Update tools/content_retriever_tool.py

Browse files
Files changed (1) hide show
  1. tools/content_retriever_tool.py +40 -35
tools/content_retriever_tool.py CHANGED
@@ -1,58 +1,63 @@
1
- from smolagents import Tool
2
  from docling.document_converter import DocumentConverter
3
  from docling.chunking import HierarchicalChunker
4
  from sentence_transformers import SentenceTransformer, util
5
  import torch
6
 
7
class ContentRetrievalTool(Tool):
    """Tool that pulls query-relevant passages out of a web page or document.

    The source is converted via docling, chunked hierarchically, and both the
    raw chunk texts and their contextual wrappers are embedded. Chunks are
    ranked against each comma-separated query term by softmax-normalized
    cosine similarity, and the contextualized text of the selected chunks is
    returned joined by blank lines.
    """

    name = 'content_retrieval'
    description = """Extracts and summarizes relevant content from webpages or documents. Supports formats like PDF, DOCX, HTML, XLSX, etc."""

    inputs = {
        "url": {
            "type": "string",
            "description": "The path or web link to the file or page to process."
        },
        "query": {
            "type": "string",
            "description": "Main subject or keyword to retrieve from the content."
        },
    }
    output_type = "string"

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', threshold: float = 0.2, **kwargs):
        """Set up the converter, chunker, and sentence-embedding model.

        Args:
            model_name: SentenceTransformer model identifier.
            threshold: Softmax-probability cutoff that stops chunk selection.
            **kwargs: Forwarded unchanged to the Tool base class.
        """
        super().__init__(**kwargs)
        self.threshold = threshold
        self._converter = DocumentConverter()
        self._chunker = HierarchicalChunker()
        self._model = SentenceTransformer(model_name)

    def forward(self, url: str, query: str) -> str:
        """Return the contextualized chunks of *url* most relevant to *query*."""
        document = self._converter.convert(url).document
        if not document:
            return "Failed to load content."

        segments = list(self._chunker.chunk(document))
        if not segments:
            return "No content detected."

        segment_texts = [seg.text for seg in segments]
        # Context-only strings: contextualized chunk with its raw text stripped out.
        segment_contexts = [
            self._chunker.contextualize(seg).replace(seg.text, "").strip()
            for seg in segments
        ]

        # Embed the raw texts and the context strings separately; both corpora
        # get to vote on which chunks are relevant.
        all_embeddings = [
            self._model.encode(segment_texts, convert_to_tensor=True),
            self._model.encode(segment_contexts, convert_to_tensor=True),
        ]
        terms = [part.strip() for part in query.split(',') if part.strip()]
        query_emb = self._model.encode(terms, convert_to_tensor=True)

        idx = set()
        for emb in all_embeddings:
            # One similarity row per query term.
            for similarity in util.pytorch_cos_sim(query_emb, emb):
                probs = torch.nn.functional.softmax(similarity, dim=0)
                ranked = torch.argsort(probs, descending=True)
                # NOTE(review): indices are accumulated from the best match
                # downward and the walk stops at the first one whose own
                # probability reaches the threshold — preserved exactly as the
                # original implemented it.
                for position in ranked:
                    idx.add(position.item())
                    if probs[position] >= self.threshold:
                        break

        if not idx:
            return "No relevant info found."
        selected = sorted(idx)
        return '\n\n'.join(self._chunker.contextualize(segments[i]) for i in selected)
 
 
 
 
1
+ from tools.base import Tool
2
  from docling.document_converter import DocumentConverter
3
  from docling.chunking import HierarchicalChunker
4
  from sentence_transformers import SentenceTransformer, util
5
  import torch
6
 
7
class ContentRetrieverTool(Tool):
    """Tool that extracts the chunks of a document most relevant to a query.

    The document at ``url`` is converted with docling, split with a
    hierarchical chunker, and both the chunk texts and their context strings
    are embedded with a SentenceTransformer. For each comma-separated query
    term, top-ranked chunks are accumulated until a cumulative
    softmax-probability mass of ``threshold`` is covered; the contextualized
    text of all selected chunks is returned joined by blank lines.
    """

    name = 'retrieve_content'
    description = "Extracts relevant content from a file or URL (PDF, DOCX, XLSX, HTML, etc.) based on a given query."

    inputs = {
        "url": {
            "type": "string",
            "description": "The document URL or local path to load content from.",
        },
        "query": {
            "type": "string",
            "description": "Query term(s) used to filter relevant content from the document.",
        },
    }
    output_type = "string"

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', threshold: float = 0.2):
        """Set up the converter, chunker, and embedding model.

        Args:
            model_name: SentenceTransformer model identifier.
            threshold: Cumulative softmax-probability mass at which chunk
                selection stops for each query term.
        """
        self.threshold = threshold
        self._converter = DocumentConverter()
        self._chunker = HierarchicalChunker()
        self._embedder = SentenceTransformer(model_name)
        super().__init__()

    def forward(self, url: str, query: str) -> str:
        """Return the contextualized chunks of *url* most relevant to *query*."""
        # Guard against a conversion that yields no document — the previous
        # revision returned an explicit failure message here and this version
        # dropped it; restored to avoid an opaque downstream crash.
        doc = self._converter.convert(url).document
        if doc is None:
            return "Failed to load content."

        chunks = list(self._chunker.chunk(dl_doc=doc))
        if not chunks:
            return "No content found."

        texts = [chunk.text for chunk in chunks]
        contextual_chunks = [self._chunker.contextualize(c) for c in chunks]
        # Context-only strings: the contextualized chunk with its raw text removed.
        context_texts = [ctx.replace(txt, "").strip() for txt, ctx in zip(texts, contextual_chunks)]

        # One embedding per comma-separated query term; an empty/blank query
        # would otherwise hand an empty list to encode() and fail downstream.
        query_terms = [q.strip() for q in query.split(",") if q.strip()]
        if not query_terms:
            return "No relevant chunks found."
        query_embedding = self._embedder.encode(query_terms, convert_to_tensor=True)

        matches = set()
        # Both the raw texts and their contexts get to vote on relevance.
        for corpus in (texts, context_texts):
            embeddings = self._embedder.encode(corpus, convert_to_tensor=True)
            # pytorch_cos_sim yields one similarity row per query term.
            for score in util.pytorch_cos_sim(query_embedding, embeddings):
                probs = torch.nn.functional.softmax(score, dim=0)
                sorted_idxs = torch.argsort(probs, descending=True)

                # Greedily take the best chunks until `threshold` probability
                # mass is covered for this query term.
                cum_prob = 0.0
                for idx in sorted_idxs:
                    cum_prob += probs[idx].item()
                    matches.add(idx.item())
                    if cum_prob >= self.threshold:
                        break

        if not matches:
            return "No relevant chunks found."

        # Preserve original document order when assembling the answer.
        selected_chunks = [contextual_chunks[i] for i in sorted(matches)]
        return "\n\n".join(selected_chunks)