Spaces:
Sleeping
Sleeping
Update tools/content_retriever_tool.py
Browse files
tools/content_retriever_tool.py
CHANGED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from smolagents import Tool
|
| 2 |
+
from docling.document_converter import DocumentConverter
|
| 3 |
+
from docling.chunking import HierarchicalChunker
|
| 4 |
+
from sentence_transformers import SentenceTransformer, util
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
class ContentRetrievalTool(Tool):
    """smolagents Tool that pulls query-relevant passages out of a document or webpage.

    Pipeline: docling converts the source (PDF, DOCX, HTML, XLSX, ...) into a
    structured document, a HierarchicalChunker splits it into chunks, and a
    SentenceTransformer scores each chunk (body text and its heading/caption
    context separately) against the comma-separated query terms. The most
    similar chunks are returned, contextualized, joined by blank lines.
    """

    name = 'content_retrieval'
    description = """Extracts and summarizes relevant content from webpages or documents. Supports formats like PDF, DOCX, HTML, XLSX, etc."""

    inputs = {
        "url": {
            "type": "string",
            "description": "The path or web link to the file or page to process."
        },
        "query": {
            "type": "string",
            "description": "Main subject or keyword to retrieve from the content."
        },
    }
    output_type = "string"

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', threshold: float = 0.2, **kwargs):
        """Build the converter/chunker/encoder pipeline.

        Args:
            model_name: SentenceTransformer checkpoint used for embeddings.
            threshold: softmax-probability cutoff used by the chunk-selection
                loop in ``forward`` (see NOTE there).
            **kwargs: forwarded to ``smolagents.Tool``.
        """
        super().__init__(**kwargs)
        self.threshold = threshold
        self._converter = DocumentConverter()
        self._chunker = HierarchicalChunker()
        # Loads (and may download) the embedding model up front.
        self._model = SentenceTransformer(model_name)

    def forward(self, url: str, query: str) -> str:
        """Return the chunks of ``url``'s content most relevant to ``query``.

        Args:
            url: path or web link to the file or page to process.
            query: subject/keywords; commas separate multiple terms.

        Returns:
            Selected chunks (with their document context) joined by blank
            lines, or a short diagnostic string when nothing could be loaded,
            chunked, or matched.
        """
        document = self._converter.convert(url).document
        if not document:
            return "Failed to load content."

        segments = list(self._chunker.chunk(document))
        if not segments:
            return "No content detected."

        segment_texts = [seg.text for seg in segments]
        # Context = the contextualized rendering minus the chunk body, i.e.
        # surrounding headings/captions, scored independently of the body.
        segment_contexts = [
            self._chunker.contextualize(seg).replace(seg.text, "").strip()
            for seg in segments
        ]

        all_embeddings = [
            self._model.encode(segment_texts, convert_to_tensor=True),
            self._model.encode(segment_contexts, convert_to_tensor=True),
        ]
        # Split the query on commas into individual terms. Fall back to the
        # raw query so an empty/all-comma string cannot produce an empty
        # batch for encode() (which would silently match nothing).
        terms = [s.strip() for s in query.split(',') if s.strip()] or [query]
        query_emb = self._model.encode(terms, convert_to_tensor=True)

        idx = set()
        for emb in all_embeddings:
            # One similarity row per query term (cosine vs. every chunk).
            for similarity in util.pytorch_cos_sim(query_emb, emb):
                probs = torch.nn.functional.softmax(similarity, dim=0)
                # Walk chunks from most to least similar, collecting each,
                # and stop once a chunk's softmax mass reaches the threshold.
                # NOTE(review): as written this keeps ADDING chunks while
                # probabilities stay below the threshold, so a diffuse
                # similarity distribution selects many chunks — confirm this
                # is the intended selection rule (vs. `< threshold: break`).
                for i in torch.argsort(probs, descending=True):
                    idx.add(i.item())
                    if probs[i] >= self.threshold:
                        break

        # Preserve document order of the selected chunks.
        selected = sorted(idx)
        if not selected:
            return "No relevant info found."
        return '\n\n'.join(self._chunker.contextualize(segments[i]) for i in selected)
|