FD900 committed on
Commit
de1b9f0
·
verified ·
1 Parent(s): 9b2def5

Update tools/content_retriever_tool.py

Browse files
Files changed (1) hide show
  1. tools/content_retriever_tool.py +58 -0
tools/content_retriever_tool.py CHANGED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from docling.document_converter import DocumentConverter
3
+ from docling.chunking import HierarchicalChunker
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import torch
6
+
7
class ContentRetrievalTool(Tool):
    """Tool that returns the chunks of a document most similar to a query.

    The document at ``url`` is converted with ``DocumentConverter``, split
    into chunks with ``HierarchicalChunker``, and each chunk is embedded
    twice with a ``SentenceTransformer`` model: once for its own text and
    once for the surrounding context added by ``contextualize``. The query
    is split on commas and every term is matched against both embedding
    sets.
    """

    name = 'content_retrieval'
    description = """Extracts and summarizes relevant content from webpages or documents. Supports formats like PDF, DOCX, HTML, XLSX, etc."""

    inputs = {
        "url": {
            "type": "string",
            "description": "The path or web link to the file or page to process."
        },
        "query": {
            "type": "string",
            "description": "Main subject or keyword to retrieve from the content."
        },
    }
    output_type = "string"

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', threshold: float = 0.2, **kwargs):
        """Build the converter, chunker and embedding model.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            threshold: Share of the softmax probability mass that the
                selected chunks must cover for each query term.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.threshold = threshold
        self._converter = DocumentConverter()
        self._chunker = HierarchicalChunker()
        self._model = SentenceTransformer(model_name)

    def forward(self, url: str, query: str) -> str:
        """Return the contextualized chunks of ``url`` that best match ``query``.

        Args:
            url: Path or web link handed to the document converter.
            query: One or more comma-separated search terms.

        Returns:
            The selected chunks in document order, joined by blank lines,
            or a short message when nothing could be loaded or matched.
        """
        document = self._converter.convert(url).document
        if not document:
            return "Failed to load content."

        segments = list(self._chunker.chunk(document))
        if not segments:
            return "No content detected."

        # An empty or comma-only query yields no terms to embed; report
        # "nothing found" instead of failing in the similarity computation.
        terms = [term.strip() for term in query.split(',') if term.strip()]
        if not terms:
            return "No relevant info found."

        # Contextualize each chunk once and reuse the result both for the
        # context embeddings and for the final output.
        contextualized = [self._chunker.contextualize(seg) for seg in segments]
        segment_texts = [seg.text for seg in segments]
        segment_contexts = [
            full.replace(seg.text, "").strip()
            for seg, full in zip(segments, contextualized)
        ]

        all_embeddings = [
            self._model.encode(segment_texts, convert_to_tensor=True),
            self._model.encode(segment_contexts, convert_to_tensor=True),
        ]
        query_emb = self._model.encode(terms, convert_to_tensor=True)

        idx = set()
        for emb in all_embeddings:
            for similarity in util.pytorch_cos_sim(query_emb, emb):
                probs = torch.nn.functional.softmax(similarity, dim=0)
                ranked = torch.argsort(probs, descending=True).tolist()
                prob_values = probs.tolist()
                # Take chunks from most to least likely until their combined
                # probability reaches the threshold. Testing each chunk's
                # individual probability here would never stop once the top
                # chunk is below the threshold, because every later chunk is
                # lower still, and the whole document would be returned.
                covered = 0.0
                for i in ranked:
                    idx.add(i)
                    covered += prob_values[i]
                    if covered >= self.threshold:
                        break

        selected = sorted(idx)
        if not selected:
            return "No relevant info found."
        return '\n\n'.join(contextualized[i] for i in selected)