Spaces:
Running
Running
| from dataclasses import dataclass | |
| from enum import Enum | |
| import torch | |
| from ask_candid.base.lambda_base import LambdaInvokeBase | |
@dataclass
class Encoding:
    """Text encoding vector response.

    Must be a dataclass: ``CandidSmallLanguageModel.encode`` constructs it with
    ``Encoding(inputs=..., vectors=...)`` keyword arguments.
    """

    # Input strings taken from the service response's "inputs" key
    # (an empty list when the key is missing or falsy).
    inputs: list[str]
    # float32 tensor built from the response's "vectors" key.
    # Presumably one row per entry in `inputs` -- TODO confirm against the service.
    vectors: torch.Tensor
@dataclass
class SummaryItem:
    """A single summary object.

    Must be a dataclass: ``CandidSmallLanguageModel.summarize`` constructs it
    with ``SummaryItem(rank=..., score=..., text=...)`` keyword arguments.
    """

    # Rank reported by the service for this snippet.
    rank: int
    # Score reported by the service for this snippet.
    score: float
    # Snippet text (taken from the service item's "value" key).
    text: str
| class TextSummary: | |
| """Text summarization response | |
| """ | |
| snippets: list[SummaryItem] | |
| def summary(self) -> str: | |
| return ' '.join([_.text for _ in self.snippets]) | |
class CandidSmallLanguageModel(LambdaInvokeBase):
    """Wrapper around Candid's custom small language model.

    For more details see https://dev.azure.com/guidestar/DataScience/_git/graph-ai?path=/releases/language.

    This service includes:
    * text encoding
    * document summarization
    * entity salience estimation

    Parameters
    ----------
    access_key : Optional[str], optional
        AWS access key, by default None
    secret_key : Optional[str], optional
        AWS secret key, by default None
    """

    class Tasks(Enum):  # noqa: D106
        # Values are sent as the "path" entry of the request payload.
        ENCODE = "/encode"
        DOCUMENT_SUMMARIZE = "/document/summarize"
        # NOTE(review): not used by any method defined in this class.
        DOCUMENT_NER_SALIENCE = "/document/entitySalience"

    def __init__(
        self, access_key: str | None = None, secret_key: str | None = None
    ) -> None:
        super().__init__(
            function_name="small-lm",
            access_key=access_key,
            secret_key=secret_key
        )

    @staticmethod
    def _ensure_dict(response: object) -> dict:
        """Return *response* unchanged if it is a dict, else raise ``TypeError``.

        Used instead of ``assert`` so the check is not stripped under ``python -O``.
        """
        if not isinstance(response, dict):
            raise TypeError(
                f"Expected a dict response, got {type(response).__name__}"
            )
        return response

    def encode(self, text: list[str]) -> Encoding:
        """Encode a batch of texts into vectors.

        Parameters
        ----------
        text : list[str]
            Texts to encode.

        Returns
        -------
        Encoding
            ``inputs`` from the response's "inputs" key and ``vectors`` as a
            float32 tensor from its "vectors" key (both empty when missing).

        Raises
        ------
        TypeError
            If the service response is not a dict.
        """
        # `_submit_request` is presumably provided by LambdaInvokeBase.
        response = self._ensure_dict(
            self._submit_request({"text": text, "path": self.Tasks.ENCODE.value})
        )
        return Encoding(
            inputs=(response.get("inputs") or []),
            vectors=torch.tensor((response.get("vectors") or []), dtype=torch.float32)
        )

    def summarize(self, text: list[str], top_k: int) -> TextSummary:
        """Summarize a document into at most ``top_k`` snippets.

        Parameters
        ----------
        text : list[str]
            Document text to summarize.
        top_k : int
            Maximum number of snippets to keep; must be non-negative. The first
            ``top_k`` items of the response's "summary" list are kept, in the
            order the service returned them.

        Returns
        -------
        TextSummary
            The kept snippets (empty when the response has no "summary").

        Raises
        ------
        ValueError
            If ``top_k`` is negative (a negative slice would silently drop
            items from the end instead of limiting the count).
        TypeError
            If the service response is not a dict.
        KeyError
            If a summary item lacks a "rank", "score" or "value" key.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        response = self._ensure_dict(
            self._submit_request({"text": text, "path": self.Tasks.DOCUMENT_SUMMARIZE.value})
        )
        items = response.get("summary") or []
        return TextSummary(
            snippets=[
                SummaryItem(rank=item["rank"], score=item["score"], text=item["value"])
                for item in items[:top_k]
            ]
        )