Spaces:
Sleeping
Sleeping
| import time | |
| import os | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams | |
| from langchain_qdrant import QdrantVectorStore | |
| from langchain_groq import ChatGroq | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from dotenv import load_dotenv | |
# Read a local .env file (python-dotenv) so the os.getenv() lookups further
# down can see GROQ_API_KEY / QDRANT_URL / QDRANT_API_KEY.
load_dotenv()
# Module-level cache of repository vector stores, keyed by repo_name.
VECTORSTORE_CACHE = {}
# Module-level cache of conversation-memory vector stores, keyed by repo_name.
MEMORY_CACHE = {}
def _repo_collection_name(repo_name):
    """Return the name of the collection holding a repository's documents."""
    return "repo_docs_{}".format(repo_name)
def _memory_collection_name(repo_name):
    """Return the name of the collection holding a repository's chat memory."""
    return "memory_{}".format(repo_name)
# Lazily created, process-wide embeddings instance (see get_embeddings_model).
_EMBEDDINGS_MODEL = None


def get_embeddings_model():
    """Return the shared sentence-embedding model, creating it on first use.

    Both get_vectorstore() and get_memory_vectorstore() call this on a cache
    miss, so building a fresh HuggingFaceEmbeddings each time would construct
    the same model twice per repository. The instance is therefore memoised
    at module level and reused by every caller.

    Returns:
        The HuggingFaceEmbeddings instance for
        "sentence-transformers/all-MiniLM-L6-v2".
    """
    global _EMBEDDINGS_MODEL
    if _EMBEDDINGS_MODEL is None:
        _EMBEDDINGS_MODEL = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    return _EMBEDDINGS_MODEL
def get_llm():
    """Build the Groq chat model client.

    Returns:
        A ChatGroq instance for "llama-3.1-8b-instant" with temperature 0.

    Raises:
        ValueError: if the GROQ_API_KEY environment variable is unset or empty.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if api_key:
        return ChatGroq(
            model="llama-3.1-8b-instant",
            temperature=0,
            api_key=api_key,
        )
    raise ValueError("GROQ_API_KEY is not set")
def _invoke_text(llm, prompt):
    """Call ``llm.invoke(prompt)`` and flatten the reply into a plain string.

    A string reply is returned unchanged. Otherwise the reply's ``content``
    attribute is used (empty string when absent): a list is flattened by
    concatenating its string items and the truthy ``"text"`` values of its
    dict items, and anything else is passed through ``str()``.
    """

    def _piece(item):
        # Items that are neither str nor dict contribute nothing.
        if isinstance(item, str):
            return item
        if isinstance(item, dict):
            return item.get("text") or ""
        return ""

    reply = llm.invoke(prompt)
    if isinstance(reply, str):
        return reply

    payload = getattr(reply, "content", "")
    if not isinstance(payload, list):
        return str(payload)
    return "".join(_piece(item) for item in payload)
def _get_client():
    """Create a Qdrant client from QDRANT_URL and QDRANT_API_KEY."""
    url = os.getenv("QDRANT_URL")
    api_key = os.getenv("QDRANT_API_KEY")
    return QdrantClient(url=url, api_key=api_key)
def _ensure_collection(client, collection_name, vector_size=384):
    """Create ``collection_name`` on ``client`` if it does not exist yet.

    Args:
        client: Qdrant client exposing ``collection_exists`` and
            ``create_collection``.
        collection_name: Name of the collection to check or create.
        vector_size: Dimensionality of the stored vectors. It must match the
            output width of the embedding model in use; the default of 384
            is the value this module has always used alongside the
            all-MiniLM-L6-v2 model configured in get_embeddings_model().

    An existing collection is left untouched, whatever its configuration.
    New collections use cosine distance.
    """
    if client.collection_exists(collection_name):
        return
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
def get_vectorstore(repo_name):
    """Return the vector store for a repository's documents.

    The store is looked up in VECTORSTORE_CACHE first. On a miss the backing
    collection is created if needed, a QdrantVectorStore is built on top of
    it, and the result is cached under ``repo_name``.
    """
    try:
        return VECTORSTORE_CACHE[repo_name]
    except KeyError:
        pass

    qdrant = _get_client()
    embedder = get_embeddings_model()
    name = _repo_collection_name(repo_name)
    _ensure_collection(qdrant, name)

    store = QdrantVectorStore(
        client=qdrant,
        collection_name=name,
        embedding=embedder,
    )
    VECTORSTORE_CACHE[repo_name] = store
    return store
def get_memory_vectorstore(repo_name):
    """Return the vector store holding a repository's conversation memory.

    The store is looked up in MEMORY_CACHE first. On a miss the backing
    collection is created if needed, a QdrantVectorStore is built on top of
    it, and the result is cached under ``repo_name``.
    """
    try:
        return MEMORY_CACHE[repo_name]
    except KeyError:
        pass

    qdrant = _get_client()
    embedder = get_embeddings_model()
    name = _memory_collection_name(repo_name)
    _ensure_collection(qdrant, name)

    store = QdrantVectorStore(
        client=qdrant,
        collection_name=name,
        embedding=embedder,
    )
    MEMORY_CACHE[repo_name] = store
    return store
def initialize_repo_caches(repo_name):
    """Warm both per-repository caches (documents first, then memory)."""
    for warm_up in (get_vectorstore, get_memory_vectorstore):
        warm_up(repo_name)
def store_memory(query, response, repo_name):
    """Persist one question/answer exchange in the repository's memory store.

    Queries whose stripped length is 10 characters or fewer are not stored.
    Each stored entry carries ``type`` and ``timestamp`` metadata.
    """
    if len(query.strip()) > 10:
        entry = f"User: {query}\nAssistant: {response}"
        store = get_memory_vectorstore(repo_name)
        entry_metadata = {
            "type": "memory",
            "timestamp": time.time(),
        }
        store.add_texts([entry], metadatas=[entry_metadata])
def get_retriever(vectorstore):
    """Return the default MMR retriever: 6 results out of 24 fetched."""
    search_options = {"k": 6, "fetch_k": 24}
    return vectorstore.as_retriever(search_type="mmr", search_kwargs=search_options)
def _get_overview_retriever(vectorstore):
    """Return the wider MMR retriever: 10 results out of 40 fetched."""
    search_options = {"k": 10, "fetch_k": 40}
    return vectorstore.as_retriever(search_type="mmr", search_kwargs=search_options)
def _looks_code_intent(query):
    """Return True when the query contains any code-related keyword.

    Matching is a case-insensitive substring test, so a keyword embedded in
    a longer word also counts.
    """
    keywords = (
        "function", "method", "class", "module", "file", "implementation", "logic",
        "algorithm", "predict", "prediction", "how does", "how is", "where is", "call",
        "returns", "parameter", "bug", "error", "traceback", "stack", "refactor",
    )
    lowered = query.lower()
    for keyword in keywords:
        if keyword in lowered:
            return True
    return False
def _looks_overview_intent(query):
    """Return True when the query asks for a repository-level summary.

    Matching is a case-insensitive substring test against a fixed phrase list.
    """
    phrases = (
        "what does this repository do",
        "what does this repo do",
        "what is this repository",
        "what is this repo",
        "repository summary",
        "repo summary",
        "overview",
        "high level",
        "purpose of",
    )
    normalized = query.lower().strip()
    for phrase in phrases:
        if phrase in normalized:
            return True
    return False
def _select_diverse_docs(docs, max_docs=8, max_per_path=2):
    """Pick documents in order while capping how many share one path.

    Args:
        docs: Documents exposing a ``metadata`` mapping; the ``"path"`` entry
            (empty string when missing) identifies the source.
        max_docs: Selection stops once this many documents are picked.
        max_per_path: Maximum number of documents kept per path.

    Returns:
        The picked documents, or ``docs[:max_docs]`` when nothing was picked.
    """
    picked = []
    taken_by_path = {}
    for candidate in docs:
        source = candidate.metadata.get("path", "")
        taken = taken_by_path.get(source, 0)
        if taken < max_per_path:
            picked.append(candidate)
            taken_by_path[source] = taken + 1
            if len(picked) >= max_docs:
                break
    if picked:
        return picked
    return docs[:max_docs]
def _rewrite_query(question, conversation_chunks, llm):
    """Turn a follow-up question into a standalone retrieval query.

    With no prior conversation the question is returned unchanged and the
    model is not called. Otherwise the model is asked to rewrite it; newlines
    in the reply become spaces and surrounding quotes/spaces are trimmed. An
    empty reply falls back to the original question.
    """
    if not conversation_chunks:
        return question

    history = "\n\n".join(conversation_chunks)
    rewrite_prompt = f"""
Rewrite the user question into a standalone retrieval query.
Use relevant details from prior conversation only when needed to resolve references.
Keep technical names, filenames, class names, and function names unchanged.
Return only the rewritten query.
Relevant Past Conversation:
{history}
Original Question:
{question}
"""
    candidate = _invoke_text(llm, rewrite_prompt).strip()
    if candidate:
        candidate = candidate.replace("\n", " ").strip('"\' ')
    return candidate if candidate else question
def ask_question(query, repo_name):
    """Answer a question about a repository using retrieval plus chat memory.

    Steps: fetch up to 3 related past exchanges, rewrite the question into a
    standalone query, retrieve repository documents (wider retrieval for
    overview questions, code-only filtering for code questions when any code
    documents were found), ask the model, then store the exchange in memory.

    Returns:
        A ``(response, repo_docs)`` tuple: the model's answer text and the
        repository documents that were placed in the prompt.
    """
    repo_store = get_vectorstore(repo_name)
    chat_model = get_llm()

    memory_retriever = get_memory_vectorstore(repo_name).as_retriever(
        search_kwargs={"k": 3}
    )
    memory_hits = memory_retriever.invoke(query)
    past_turns = [hit.page_content for hit in memory_hits]
    standalone_query = _rewrite_query(query, past_turns, chat_model)

    # Intent is checked on the raw question first, then on the rewritten one.
    candidates = (query, standalone_query)
    wants_overview = any(_looks_overview_intent(text) for text in candidates)

    if wants_overview:
        retriever = _get_overview_retriever(repo_store)
        doc_budget = 10
    else:
        retriever = get_retriever(repo_store)
        doc_budget = 8
    repo_docs = _select_diverse_docs(
        retriever.invoke(standalone_query), max_docs=doc_budget
    )

    if not wants_overview and any(_looks_code_intent(text) for text in candidates):
        code_only = [doc for doc in repo_docs if doc.metadata.get("type") == "code"]
        # Keep the unfiltered selection when no code documents were retrieved.
        if code_only:
            repo_docs = _select_diverse_docs(code_only, max_docs=8)

    conversation_context = "\n\n".join(past_turns) or "None"
    code_context = "\n\n".join(doc.page_content for doc in repo_docs)
    context = (
        f"Relevant Past Conversation:\n{conversation_context}\n\n"
        f"Relevant Code Context:\n{code_context}\n\n"
        f"Question:\n{query}"
    )
    prompt = f"""
You are a senior software engineer.
Use:
* Relevant Past Conversation to resolve references like "that function"
* Relevant Code Context for factual answers
If exact answer is missing, infer logically from code and mention it is an inference.
Be concise and technical.
Context:
{context}
"""
    response = _invoke_text(chat_model, prompt)
    store_memory(query, response, repo_name)
    return response, repo_docs