Spaces:
Configuration error
Configuration error
| import os | |
| import torch | |
| import logging | |
| import preprocess | |
| from qdrant_client import QdrantClient | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Configure logging once for the whole script.
logging.basicConfig(level=logging.INFO)

# Qdrant connection settings; both can be overridden through the environment.
# The default port is given as a string because os.getenv returns strings.
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

# Client used for vector search.
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)

# Sentence-transformer that turns queries into embedding vectors.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# GPT-2 checkpoint from Hugging Face. "gpt2-medium", "gpt2-large" and
# "gpt2-xl" are larger drop-in alternatives.
GPT2_MODEL_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(GPT2_MODEL_NAME)

# Use half precision only when a GPU is present: float16 saves memory there,
# but half-precision kernels are missing or very slow on CPU in many PyTorch
# builds, so fall back to float32 on CPU-only hosts.
gpt2_model = AutoModelForCausalLM.from_pretrained(
    GPT2_MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # Auto-select GPU if available
)
| # Function to Generate Answer Using GPT-2 | |
def generate_answer(query, context):
    """Generate an answer to *query* with GPT-2, grounded in *context*.

    Args:
        query: The user's question; inserted verbatim into the prompt.
        context: Retrieved text the model should answer from. ``None``, an
            empty string, or whitespace-only text short-circuits without
            running the model.

    Returns:
        The newly generated text only (the prompt is not echoed back), or a
        fixed fallback sentence when there is no usable context.
    """
    if not context or not context.strip():
        return "I couldn't find relevant information."

    max_new_tokens = 200
    context = _fit_context(query, context, max_new_tokens)
    prompt = _build_prompt(query, context)

    # Safety net: even after trimming the context, never feed more tokens
    # than the model window minus the room reserved for the reply.
    window = getattr(gpt2_model.config, "n_positions", 1024)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max(1, window - max_new_tokens),
    ).to(gpt2_model.device)

    outputs = gpt2_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # temperature/top_p only take effect when sampling is enabled;
        # without do_sample=True generation is greedy and they are ignored.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # GPT-2 defines no pad token; reuse EOS so generate() has one.
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation
    # so callers get the answer rather than an echo of the context.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()


def _build_prompt(query, context):
    """Return the Context/Question/Answer prompt fed to GPT-2."""
    return f"\nContext: {context}\nQuestion: {query}\nAnswer:\n"


def _fit_context(query, context, max_new_tokens):
    """Trim *context* so prompt plus reply fit in the model's position window."""
    window = getattr(gpt2_model.config, "n_positions", 1024)
    # Tokens consumed by everything in the prompt except the context itself.
    overhead = len(tokenizer.encode(_build_prompt(query, "")))
    budget = window - max_new_tokens - overhead
    context_ids = tokenizer.encode(context)
    if len(context_ids) <= budget:
        return context
    # Drop the tail of the context so the question and "Answer:" cue survive.
    return tokenizer.decode(context_ids[:max(budget, 0)])
| # Function to Query Documents from Qdrant | |
def query_documents(collection_name, user_query, top_k=5, score_threshold=0.5):
    """Search Qdrant for *user_query* and generate an answer from the hits.

    Args:
        collection_name: Name of the Qdrant collection to search.
        user_query: The user's question as free text.
        top_k: Maximum number of hits requested from Qdrant.
        score_threshold: Hits scoring below this value are discarded.

    Returns:
        ``{"answer": str, "chunks": list[dict]}`` on success, where each
        chunk has ``id``, ``score`` and ``text`` keys; ``{"error": str}``
        when the query is empty or any step raises.
    """
    # Reject unusable queries up front instead of embedding an empty string.
    if not isinstance(user_query, str) or not user_query.strip():
        logging.warning("Rejected empty or non-string query.")
        return {"error": "Query must be a non-empty string."}

    try:
        logging.info(f"🔍 Original Query: {user_query}")
        processed_query = preprocess.preprocess_text(user_query)
        logging.info(f" Preprocessed Query: {processed_query}")

        # Generate Query Embedding
        query_vector = embedding_model.encode(processed_query).tolist()

        # Search in Qdrant
        search_results = qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=top_k,
            with_payload=True,
        )
        if not search_results:
            logging.warning(" No results found. Try increasing top_k or checking indexing.")

        # Keep only confident hits that actually carry text. A point may be
        # stored without a payload, so treat a missing payload as empty.
        filtered_results = []
        for hit in search_results:
            payload = hit.payload or {}
            if hit.score >= score_threshold and "text" in payload:
                filtered_results.append(
                    {"id": hit.id, "score": hit.score, "text": payload["text"]}
                )

        # Pass an empty context through unchanged when nothing survived the
        # filter, so generate_answer's empty-context guard returns its
        # fallback instead of the model being prompted with a placeholder.
        context = " ".join(chunk["text"] for chunk in filtered_results)
        answer = generate_answer(user_query, context)
        return {"answer": answer, "chunks": filtered_results}
    except Exception as e:
        # Top-level boundary: report the failure to the caller as data, but
        # keep the traceback in the log for debugging.
        logging.exception(f"Error during query: {e}")
        return {"error": str(e)}
| # Command-Line Execution | |
# Script entry point. The guard must compare the dunder name ``__name__``
# against "__main__"; single underscores raise NameError at import time.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Query documents with GPT-2")
    parser.add_argument("--collection", type=str, default="documents", help="Qdrant collection name")
    parser.add_argument("--query", type=str, required=True, help="Your search query")
    parser.add_argument("--top-k", type=int, default=3, help="Number of results to return")
    args = parser.parse_args()

    logging.info(f"Querying for: '{args.query}'")
    result = query_documents(args.collection, args.query, args.top_k)

    if "error" in result:
        logging.error(f" Error: {result['error']}")
        # Signal failure to the calling shell instead of exiting with 0.
        raise SystemExit(1)

    logging.info("\n=== Generated Answer ===")
    print(result["answer"])
    logging.info("\n=== Relevant Chunks ===")
    for i, chunk in enumerate(result["chunks"], start=1):
        print(f"\nChunk {i} (Score: {chunk['score']:.3f}):\n{chunk['text']}")