File size: 5,640 Bytes
94b1baf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | # tools.py - Real vector retrieval for query_docs, linter, and test runner
import os
import subprocess
import sys
import tempfile
from dataclasses import dataclass
try:
from sentence_transformers import SentenceTransformer
except ImportError:
SentenceTransformer = None
try:
import chromadb
except ImportError:
chromadb = None
@dataclass
class ToolBox:
    """Agent tools: documentation retrieval, pylint linting, and a script runner.

    Every public tool returns a human-readable string instead of raising, so
    the result can be handed straight back as tool output. The embedding
    model, Chroma client and collection are cached on the class, so they are
    created at most once and shared by all instances.
    """

    # Lazily initialised, class-wide caches. Left unannotated on purpose so
    # @dataclass does not turn them into __init__ fields.
    _embedder = None
    _client = None
    _collection = None

    # sentence-transformers model used to embed both documents and queries.
    _MODEL_NAME = "all-MiniLM-L6-v2"
    # Name of the Chroma collection that holds the seed documents.
    _COLLECTION_NAME = "docs"
    # Seed corpus indexed into the vector collection on first use.
    _SEED_DOCS = (
        "KeyError occurs when a dictionary key is missing. Use dict.get() or check 'if key in dict'.",
        "pylint error C0304: missing final newline. Add a newline at the end of file.",
        "Deadlock happens when two threads acquire locks in opposite order. Always acquire locks in the same order.",
        "Division by zero: check if list is empty before calculating average, or use try/except.",
        "Threading.Lock: use 'with lock:' to automatically acquire and release.",
        "Off-by-one errors: adjust loop ranges, e.g., range(1, len(arr)-1).",
    )
    # Keyword -> snippet table used when vector retrieval is unavailable or
    # fails. Scanned in insertion order; the first keyword contained in the
    # lower-cased topic wins.
    _FALLBACK_DOCS = {
        "null check": "To avoid KeyError, use 'if key in dict:' before accessing.",
        "keyerror": "Catch KeyError with try/except or use dict.get().",
        "deadlock": "Always acquire locks in the same order to avoid deadlock.",
        "race": "Protect shared state with a lock or make the update atomic.",
        "division": "Guard empty inputs before dividing or return a safe default.",
    }

    @classmethod
    def _get_embedder(cls):
        """Return the shared SentenceTransformer, loading it on first use.

        Returns None when sentence-transformers is not installed. Errors
        raised while loading the model propagate to the caller.
        """
        if cls._embedder is None:
            if SentenceTransformer is None:
                return None
            cls._embedder = SentenceTransformer(cls._MODEL_NAME)
        return cls._embedder

    @classmethod
    def _get_collection(cls):
        """Return the shared Chroma collection, creating and seeding it on first use.

        Returns None when chromadb or sentence-transformers is unavailable.
        The collection is cached only after it has been fully populated, so a
        failure part-way through (e.g. while encoding) is retried on the next
        call instead of leaving an empty collection cached forever.
        """
        if cls._collection is not None:
            return cls._collection
        if chromadb is None:
            return None
        # Check the embedder before touching Chroma so nothing half-built is
        # left behind when embeddings cannot be produced.
        embedder = cls._get_embedder()
        if embedder is None:
            return None
        if cls._client is None:
            cls._client = chromadb.Client()
        # get_or_create (not create) so a retry after a failed seeding attempt
        # does not fail on a collection that already exists in the client.
        collection = cls._client.get_or_create_collection(cls._COLLECTION_NAME)
        docs = list(cls._SEED_DOCS)
        embeddings = embedder.encode(docs).tolist()
        collection.add(
            ids=[str(i) for i in range(len(docs))],
            documents=docs,
            embeddings=embeddings,
        )
        cls._collection = collection
        return cls._collection

    @staticmethod
    def _write_temp_script(source: str) -> str:
        """Write *source* to a new UTF-8 ``.py`` temp file and return its path.

        The caller owns the file and must delete it. If writing fails (e.g. the
        text contains characters UTF-8 cannot encode), the file is removed
        before the error propagates, so nothing is leaked.
        """
        handle = tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        )
        try:
            with handle:
                handle.write(source)
        except BaseException:
            ToolBox._remove_temp(handle.name)
            raise
        return handle.name

    @staticmethod
    def _remove_temp(path: str) -> None:
        """Best-effort delete of *path*; OSError (missing/locked file) is ignored."""
        try:
            os.unlink(path)
        except OSError:
            pass

    @staticmethod
    def _clean_pylint_output(stdout: str) -> str:
        """Strip pylint's score footer from *stdout* and return the message text.

        An empty result means pylint reported no messages.
        """
        body = stdout.split("Your code has been rated", 1)[0]
        # pylint prints a line of dashes just before the score line; drop it
        # too, otherwise a clean run is reported as "-----" instead of empty.
        return body.rstrip().rstrip("-").strip()

    @staticmethod
    def run_linter(code: str, timeout: float = 10) -> str:
        """Run pylint on *code* and return its messages (truncated to 500 chars).

        Args:
            code: Python source to lint.
            timeout: Seconds to wait for pylint before giving up.

        Returns:
            The pylint messages; "No linting issues found." for a clean run;
            or a short diagnostic string when pylint is missing, times out,
            or fails.
        """
        try:
            tmp_path = ToolBox._write_temp_script(code)
        except (OSError, ValueError) as e:
            return f"Linter error: {str(e)}"
        try:
            result = subprocess.run(
                [sys.executable, "-m", "pylint", tmp_path, "--exit-zero", "--output-format=text"],
                capture_output=True,
                text=True,
                timeout=timeout,
                encoding="utf-8",
                errors="replace",
            )
            # sys.executable always exists, so a missing pylint surfaces as an
            # interpreter error on stderr rather than FileNotFoundError.
            if "No module named pylint" in result.stderr:
                return "Linter (pylint) not installed."
            output = ToolBox._clean_pylint_output(result.stdout)
            if not output:
                return "No linting issues found."
            return output[:500]
        except FileNotFoundError:
            return "Linter (pylint) not installed."
        except subprocess.TimeoutExpired:
            return "Linter timed out."
        except Exception as e:
            return f"Linter error: {str(e)}"
        finally:
            ToolBox._remove_temp(tmp_path)

    @staticmethod
    def run_tests(test_script: str, timeout: float = 10) -> str:
        """Execute *test_script* as a standalone Python program and report its output.

        SECURITY NOTE: the script runs unsandboxed with the current interpreter
        and the caller's privileges; only pass code you are willing to execute.

        Args:
            test_script: Python source to run.
            timeout: Seconds to wait before the child is killed.

        Returns:
            Combined stdout+stderr (stripped). If there is no output, a message
            that distinguishes a clean exit from a non-zero exit code. On
            timeout or runner failure, a short diagnostic string.
        """
        try:
            tmp_path = ToolBox._write_temp_script(test_script)
        except (OSError, ValueError) as e:
            return f"Test runner error: {str(e)}"
        try:
            result = subprocess.run(
                [sys.executable, tmp_path],
                capture_output=True,
                text=True,
                timeout=timeout,
                encoding="utf-8",
                errors="replace",
            )
            output = (result.stdout + result.stderr).strip()
            if output:
                return output
            # A silent non-zero exit (e.g. sys.exit(1)) is a failure, not success.
            if result.returncode != 0:
                return f"Test exited with code {result.returncode} (no output)."
            return "Test executed successfully (no output)."
        except subprocess.TimeoutExpired:
            return "Test execution timed out."
        except Exception as e:
            return f"Test runner error: {str(e)}"
        finally:
            ToolBox._remove_temp(tmp_path)

    @classmethod
    def _fallback_docs(cls, topic: str) -> str:
        """Keyword lookup used when vector retrieval is unavailable."""
        topic_lower = topic.lower()
        return next(
            (snippet for key, snippet in cls._FALLBACK_DOCS.items() if key in topic_lower),
            "No relevant documentation found. Try being more specific.",
        )

    @classmethod
    def query_docs(cls, topic: str, n_results: int = 3) -> str:
        """Retrieve the *n_results* most relevant docs for *topic*.

        Uses vector search when sentence-transformers and chromadb are usable.
        On any failure (missing dependencies, model load error, query error),
        it falls back to a keyword lookup instead of raising.

        Args:
            topic: Free-text description of what to look up.
            n_results: Maximum number of snippets to return from vector search.

        Returns:
            A numbered list of snippets prefixed by "Relevant documentation:",
            a single fallback snippet, or a "No relevant documentation found"
            message.
        """
        try:
            embedder = cls._get_embedder()
            collection = cls._get_collection()
            if embedder is None or collection is None:
                raise RuntimeError("Vector retrieval dependencies are unavailable")
            query_emb = embedder.encode([topic]).tolist()
            results = collection.query(query_embeddings=query_emb, n_results=n_results)
            if results["documents"] and results["documents"][0]:
                snippets = [
                    f"[{rank}] {doc}"
                    for rank, doc in enumerate(results["documents"][0], start=1)
                ]
                return "Relevant documentation:\n" + "\n".join(snippets)
            return "No relevant documentation found."
        except Exception:
            # Deliberate best-effort: retrieval problems degrade to keywords.
            return cls._fallback_docs(topic)
|