# NOTE(review): the three lines below were paste/extraction residue
# ("Spaces:" and two "Runtime error" banners) — kept as comments so the
# module parses; safe to delete once confirmed they carry no meaning.
# Spaces:
# Runtime error
# Runtime error
import json
import os

import torch
from sentence_transformers import SentenceTransformer

from .utils import get_md5
class ToolRAGModel:
    """Dense-retrieval index over a toolbox's tool descriptions.

    Wraps a SentenceTransformer encoder: tool descriptions are embedded once
    (and cached on disk, keyed by an MD5 of the serialized tools), then
    queries are embedded at inference time and matched by similarity.
    """

    def __init__(self, rag_model_name):
        """Load the encoder immediately.

        Args:
            rag_model_name: Hugging Face model id or local path passed to
                ``SentenceTransformer``.
        """
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()

    def load_rag_model(self):
        """Instantiate the sentence-transformer encoder and configure it."""
        self.rag_model = SentenceTransformer(self.rag_model_name)
        # Long tool descriptions: allow up to 4096 tokens per input.
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(self, toolbox):
        """Build (or load from cache) the embedding matrix for all tools.

        The cache file name embeds an MD5 of the serialized tool prompts, so
        any change to the toolbox invalidates the cache automatically.

        Args:
            toolbox: project object exposing ``refresh_tool_name_desc``,
                ``prepare_tool_prompts`` and ``all_tools`` (schema assumed
                from usage here — confirm against the toolbox class).
        """
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        all_tools_str = [
            json.dumps(each)
            for each in toolbox.prepare_tool_prompts(toolbox.all_tools)
        ]
        md5_value = get_md5(str(all_tools_str))
        print("Computed MD5 for tool embedding:", md5_value)
        self.tool_embedding_path = os.path.join(
            os.path.dirname(__file__),
            self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt",
        )
        if os.path.exists(self.tool_embedding_path):
            try:
                # NOTE(review): torch.load on a cache file — fine for a
                # self-produced cache, but unsafe on untrusted files;
                # consider weights_only=True if the payload is tensor-only.
                cached = torch.load(self.tool_embedding_path, map_location="cpu")
                # Explicit check instead of `assert`: asserts are stripped
                # under `python -O`, which would silently accept a stale
                # cache. ValueError is caught below, forcing regeneration.
                if len(cached) != len(toolbox.all_tools):
                    raise ValueError("Tool count mismatch with loaded embeddings.")
                self.tool_desc_embedding = cached
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                # Any load/validation failure falls through to regeneration.
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None
        print("\033[93mGenerating new tool_desc_embedding...\033[0m")
        self.tool_desc_embedding = self.rag_model.encode(
            all_tools_str, prompt="", normalize_embeddings=True
        )
        torch.save(self.tool_desc_embedding, self.tool_embedding_path)
        print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")

    def rag_infer(self, query, top_k=5):
        """Return the names of the ``top_k`` tools most similar to *query*.

        Args:
            query: natural-language query string.
            top_k: number of tool names to return; clamped to the number of
                known tools.

        Returns:
            List of tool-name entries (taken from ``self.tool_name``) ranked
            by descending similarity.

        Raises:
            RuntimeError: if ``load_tool_desc_embedding`` was never called.
        """
        # Fail fast BEFORE paying for an encoder forward pass on the query.
        if self.tool_desc_embedding is None:
            raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")
        torch.cuda.empty_cache()
        query_embeddings = self.rag_model.encode(
            [query], prompt="", normalize_embeddings=True
        )
        scores = self.rag_model.similarity(
            query_embeddings, self.tool_desc_embedding
        )
        top_k = min(top_k, len(self.tool_name))
        # scores has shape (1, n_tools); take the single query's row.
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        return [self.tool_name[i] for i in top_k_indices]