Spaces:
Sleeping
Sleeping
| from smolagents import Tool | |
| from docling.document_converter import DocumentConverter | |
| from docling.chunking import HierarchicalChunker | |
| from sentence_transformers import SentenceTransformer, util | |
| import torch | |
class ContentRetrievalTool(Tool):
    """smolagents Tool that retrieves query-relevant chunks from a document or webpage.

    Pipeline: docling converts the URL/path into a structured document, a
    hierarchical chunker splits it into segments, and a sentence-transformer
    scores each segment (text and its surrounding context separately) against
    the comma-separated query terms. Segments whose softmax-normalized
    similarity mass is at least ``threshold`` are returned, contextualized,
    joined by blank lines.
    """

    name = 'content_retrieval'
    description = """Extracts and summarizes relevant content from webpages or documents. Supports formats like PDF, DOCX, HTML, XLSX, etc."""
    inputs = {
        "url": {
            "type": "string",
            "description": "The path or web link to the file or page to process."
        },
        "query": {
            "type": "string",
            "description": "Main subject or keyword to retrieve from the content."
        },
    }
    output_type = "string"

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', threshold: float = 0.2, **kwargs):
        """Initialize converter, chunker and embedding model.

        Args:
            model_name: sentence-transformers model used for embeddings.
            threshold: minimum softmax probability mass a segment must carry
                (per query term, per embedding view) to be selected.
            **kwargs: forwarded to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.threshold = threshold
        self._converter = DocumentConverter()
        self._chunker = HierarchicalChunker()
        self._model = SentenceTransformer(model_name)

    def forward(self, url: str, query: str) -> str:
        """Return the contextualized segments of ``url`` most relevant to ``query``.

        Args:
            url: path or web link accepted by docling's DocumentConverter.
            query: one or more comma-separated search terms.

        Returns:
            Relevant segments joined by blank lines, or a short status
            message when nothing could be loaded, chunked, or matched.
        """
        document = self._converter.convert(url).document
        if not document:
            return "Failed to load content."

        segments = list(self._chunker.chunk(document))
        if not segments:
            return "No content detected."

        segment_texts = [seg.text for seg in segments]
        # Context = contextualized rendering minus the raw text, so headings /
        # surrounding structure are scored independently of the body text.
        segment_contexts = [
            self._chunker.contextualize(seg).replace(seg.text, "").strip()
            for seg in segments
        ]
        all_embeddings = [
            self._model.encode(segment_texts, convert_to_tensor=True),
            self._model.encode(segment_contexts, convert_to_tensor=True),
        ]
        # Each comma-separated term is embedded and matched separately.
        query_emb = self._model.encode(
            [s.strip() for s in query.split(',') if s.strip()],
            convert_to_tensor=True,
        )

        idx = set()
        for emb in all_embeddings:
            # One similarity row per query term against every segment.
            for similarity in util.pytorch_cos_sim(query_emb, emb):
                probs = torch.nn.functional.softmax(similarity, dim=0)
                # Visit segments from most to least relevant and keep those
                # whose probability mass meets the threshold. (The previous
                # version broke on `probs[i] >= threshold`, which — given the
                # descending order — kept either a single segment or all of
                # them; the condition was inverted.)
                for i in torch.argsort(probs, descending=True):
                    if probs[i] < self.threshold:
                        break
                    idx.add(i.item())

        selected = sorted(idx)
        if not selected:
            return "No relevant info found."
        return '\n\n'.join(self._chunker.contextualize(segments[i]) for i in selected)