# knowledge-app/querying.py
import os
import torch
import logging
import preprocess
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# Configure Logging
logging.basicConfig(level=logging.INFO)
# Load Qdrant Configuration from Environment
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", 6333))
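# Example: override the defaults at launch (host name here is a placeholder), e.g.
#   QDRANT_HOST=my-qdrant-host QDRANT_PORT=6333 python querying.py --query "..."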
# Initialize Qdrant Client
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
# Load Sentence Transformer for Query Embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
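# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings; the collection
# queried below is assumed to have been created with a matching vector size.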
# Load GPT-2 from Hugging Face
GPT2_MODEL_NAME = "gpt2" # You can also use "gpt2-medium", "gpt2-large", "gpt2-xl" for larger versions
tokenizer = AutoTokenizer.from_pretrained(GPT2_MODEL_NAME)
gpt2_model = AutoModelForCausalLM.from_pretrained(
GPT2_MODEL_NAME,
torch_dtype=torch.float16, # Lower memory usage
device_map="auto" # Auto-select GPU if available
)
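# Note: device_map="auto" requires the `accelerate` package, and float16
# weights mainly pay off on GPU. A plain fallback (a sketch, not part of the
# original setup) would be:
#
#   gpt2_model = AutoModelForCausalLM.from_pretrained(GPT2_MODEL_NAME)
#   gpt2_model.to("cuda" if torch.cuda.is_available() else "cpu")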
# Function to Generate Answer Using GPT-2
def generate_answer(query, context):
"""Generates a response using GPT-2 based on the retrieved context."""
if not context.strip():
return "I couldn't find relevant information."
prompt = f"""
Context: {context}
Question: {query}
Answer:
"""
    # GPT-2's context window is 1024 tokens; truncate the prompt to leave
    # room for the generated answer.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=824
    ).to(gpt2_model.device)
    outputs = gpt2_model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,  # Required for temperature/top_p to have any effect
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
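# Example usage (hypothetical context string):
#   generate_answer("What is Qdrant?", "Qdrant is an open-source vector database.")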
# Function to Query Documents from Qdrant
def query_documents(collection_name, user_query, top_k=5, score_threshold=0.5):
"""Queries Qdrant, retrieves matching documents, and generates an answer using GPT-2."""
try:
logging.info(f"🔍 Original Query: {user_query}")
processed_query = preprocess.preprocess_text(user_query)
logging.info(f" Preprocessed Query: {processed_query}")
# Generate Query Embedding
query_vector = embedding_model.encode(processed_query).tolist()
# Search in Qdrant
search_results = qdrant_client.search(
collection_name=collection_name,
query_vector=query_vector,
limit=top_k,
with_payload=True
)
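        # Note: the meaning of `score` depends on the distance metric the
        # collection was created with; the default score_threshold of 0.5
        # assumes cosine similarity (an assumption about the indexing side).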
        if not search_results:
            logging.warning("No results found. Try increasing top_k or checking indexing.")
            return {"answer": "I couldn't find relevant information.", "chunks": []}
# Filter Results
filtered_results = [
{
"id": res.id,
"score": res.score,
"text": res.payload.get("text", ""),
}
for res in search_results if res.score >= score_threshold and "text" in res.payload
]
        # Extract Context for Answer Generation; an empty string triggers the
        # fallback message inside generate_answer
        context = " ".join(res["text"] for res in filtered_results)
answer = generate_answer(user_query, context)
return {"answer": answer, "chunks": filtered_results}
except Exception as e:
logging.error(f"Error during query: {e}")
return {"error": str(e)}
# Command-Line Execution
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Query documents with GPT-2")
parser.add_argument("--collection", type=str, default="documents", help="Qdrant collection name")
parser.add_argument("--query", type=str, required=True, help="Your search query")
parser.add_argument("--top-k", type=int, default=3, help="Number of results to return")
args = parser.parse_args()
logging.info(f"Querying for: '{args.query}'")
result = query_documents(args.collection, args.query, args.top_k)
if "error" in result:
logging.error(f" Error: {result['error']}")
else:
logging.info("\n=== Generated Answer ===")
print(result["answer"])
logging.info("\n=== Relevant Chunks ===")
for i, chunk in enumerate(result["chunks"]):
print(f"\nChunk {i+1} (Score: {chunk['score']:.3f}):\n{chunk['text']}")