Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,8 @@ from tooluniverse import ToolUniverse
|
|
| 12 |
# Patch PyTorch to allow loading old numpy pickles
|
| 13 |
torch.serialization.add_safe_globals([
|
| 14 |
numpy.core.multiarray._reconstruct,
|
| 15 |
-
numpy.ndarray
|
|
|
|
| 16 |
])
|
| 17 |
|
| 18 |
logging.basicConfig(
|
|
@@ -63,7 +64,11 @@ def patch_embedding_loading():
|
|
| 63 |
tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
|
| 64 |
if len(tools) != len(self.tool_desc_embedding):
|
| 65 |
logger.warning("Tool count mismatch.")
|
| 66 |
-
self.tool_desc_embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
return True
|
| 68 |
except Exception as e:
|
| 69 |
logger.error(f"Embedding load failed: {e}")
|
|
|
|
| 12 |
# Patch PyTorch to allow loading old numpy pickles
|
| 13 |
torch.serialization.add_safe_globals([
|
| 14 |
numpy.core.multiarray._reconstruct,
|
| 15 |
+
numpy.ndarray,
|
| 16 |
+
numpy.dtype
|
| 17 |
])
|
| 18 |
|
| 19 |
logging.basicConfig(
|
|
|
|
| 64 |
tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
|
| 65 |
if len(tools) != len(self.tool_desc_embedding):
|
| 66 |
logger.warning("Tool count mismatch.")
|
| 67 |
+
if len(self.tool_desc_embedding) > len(tools):
|
| 68 |
+
self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
|
| 69 |
+
else:
|
| 70 |
+
padding = self.tool_desc_embedding[-1].unsqueeze(0).repeat(len(tools) - len(self.tool_desc_embedding), 1)
|
| 71 |
+
self.tool_desc_embedding = torch.cat([self.tool_desc_embedding, padding], dim=0)
|
| 72 |
return True
|
| 73 |
except Exception as e:
|
| 74 |
logger.error(f"Embedding load failed: {e}")
|