Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Any | |
| import re | |
| from loguru import logger | |
| from rag_demo.preprocessing.embed import EmbeddedChunk | |
| from transformers import pipeline | |
class SourceAnnotator:
    """Append a source citation to every sentence of a generated response.

    Each sentence is scored against each retrieved chunk with an extractive
    question-answering pipeline (the sentence is passed as the question and
    the chunk text as the context). The chunk with the highest score is
    cited right after the sentence.
    """

    # A sentence boundary is whitespace preceded by '.', '!' or '?' and
    # followed by an uppercase ASCII letter. Compiled once for all calls.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

    def __init__(self):
        """Load the extractive question-answering pipeline used for scoring."""
        self.source_annotator = pipeline(
            "question-answering",
            model="distilbert/distilbert-base-cased-distilled-squad",
        )

    def annotate(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Return ``response`` with a source citation after each sentence.

        Args:
            response: Generated text to annotate.
            reranked_chunks: Candidate source chunks. Each one must expose
                ``content``, ``metadata["filename"]`` and ``chunk_id``.

        Returns:
            The sentences of ``response``, each followed by
            ``[filename: <name>, chunk_id: <id>]`` and a single space.
            When ``reranked_chunks`` is empty there is nothing to cite, so
            ``response`` is returned unchanged.
        """
        if not reranked_chunks:
            # max() over an empty candidate list would raise ValueError.
            return response

        annotated_parts = []
        for sentence in self.split_sentences(response):
            # Score the sentence against every chunk. The pipeline result is
            # a dict that carries the confidence under the "score" key.
            scored_chunks = [
                (
                    self.source_annotator(
                        question=sentence, context=chunk.content
                    )["score"],
                    chunk,
                )
                for chunk in reranked_chunks
            ]
            # Could also use a score cut-off instead of max().
            # On ties max() keeps the first (highest-ranked) chunk.
            _, best_chunk = max(scored_chunks, key=lambda pair: pair[0])
            # Everything before the first ".pdf" is used as the display name.
            filename = best_chunk.metadata["filename"].split(".pdf")[0]
            annotated_parts.append(
                f"{sentence} [filename: {filename}, chunk_id: {best_chunk.chunk_id}] "
            )
        return "".join(annotated_parts)

    def split_sentences(self, text: str) -> list[str]:
        """Split ``text`` into stripped, non-empty sentences.

        A split happens at whitespace that follows ``.``, ``!`` or ``?`` and
        precedes an uppercase ASCII letter, so "e.g. foo" stays in one piece
        while "Done. Next" is split.
        """
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]