| from sentence_transformers import SentenceTransformer, util |
| import chromadb |
| from chromadb.utils import embedding_functions |
| import numpy as np |
|
|
class SimpleRetriever:
    """Nearest-incident lookup backed by an in-memory Chroma collection.

    On construction the retriever opens (or reuses) a Chroma collection and
    seeds it with a fixed set of example incident descriptions, each tagged
    with a ``cause`` label in its metadata. :meth:`get_similarity` scores how
    close a free-text query is to the best-matching seeded incident.
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        collection_name: str = "incidents",
    ):
        """Load the embedding model, open the collection, and seed it.

        Args:
            model_name: Sentence-Transformers model name. It is used both for
                ``self.encoder`` and for the collection's embedding function,
                so the two cannot silently disagree.
            collection_name: Name of the Chroma collection to create or reuse.
        """
        # Public attribute kept for callers. The collection below embeds
        # documents through its own embedding function, not through this object.
        self.encoder = SentenceTransformer(model_name)
        self.client = chromadb.Client()
        # get_or_create (rather than create) so that building the retriever
        # when the collection already exists on this client does not raise.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=model_name
            ),
        )
        self._seed_incidents()

    def _seed_incidents(self):
        """Write the built-in example incidents into the collection.

        Each incident's cause label doubles as its document id. All incidents
        are sent in a single ``upsert`` call. This embeds them as one batch,
        and re-seeding a reused collection refreshes the existing ids instead
        of colliding with them.
        """
        incidents = [
            ("High latency in payment service, caused by database connection pool exhaustion.", "database_pool"),
            ("Memory leak in API gateway after 24 hours of uptime.", "memory_leak"),
            ("Authentication service returning 500 errors due to misconfigured OAuth.", "oauth_config"),
            ("Disk full on logging node, causing log loss.", "disk_full"),
        ]
        documents = [text for text, _ in incidents]
        causes = [cause for _, cause in incidents]
        self.collection.upsert(
            documents=documents,
            metadatas=[{"cause": cause} for cause in causes],
            ids=causes,
        )

    def get_similarity(self, query: str) -> float:
        """Return a similarity score for the closest seeded incident.

        The nearest neighbour's distance ``d`` reported by Chroma is mapped to
        ``1 / (1 + d)``. A distance of 0 gives 1.0, and the score shrinks
        toward 0 as the distance grows. The collection is created without an
        ``hnsw:space`` setting, so Chroma's default distance (squared L2)
        applies.

        Args:
            query: Free-text description to compare against the incidents.

        Returns:
            The similarity score, or 0.0 when the query produced no neighbour
            or no distance information.
        """
        results = self.collection.query(query_texts=[query], n_results=1)
        # Chroma nests results per query text; we sent exactly one query.
        distances = results.get("distances")
        if not distances or not distances[0]:
            return 0.0
        return 1.0 / (1.0 + float(distances[0][0]))