File size: 4,093 Bytes
42da79c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
import logging
import preprocess
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure Logging
logging.basicConfig(level=logging.INFO)

# Load Qdrant Configuration from Environment
# (os.getenv defaults are given as strings so both branches go through int()).
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

# Initialize Qdrant Client
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)

# Load Sentence Transformer for Query Embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Load GPT-2 from Hugging Face
GPT2_MODEL_NAME = "gpt2"  # You can also use "gpt2-medium", "gpt2-large", "gpt2-xl" for larger versions
tokenizer = AutoTokenizer.from_pretrained(GPT2_MODEL_NAME)

# Half precision is only requested when a CUDA device is present; on a
# CPU-only host float16 kernels are slow or missing, so fall back to float32.
_GPT2_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

gpt2_model = AutoModelForCausalLM.from_pretrained(
    GPT2_MODEL_NAME,
    torch_dtype=_GPT2_DTYPE,  # Lower memory usage on GPU
    device_map="auto",  # Auto-select GPU if available
)

# Function to Generate Answer Using GPT-2
def generate_answer(query, context):
    """Generate an answer to ``query`` with GPT-2, grounded in ``context``.

    Args:
        query: The user's question; inserted verbatim into the prompt.
        context: Retrieved text the answer should be based on. ``None``,
            empty, or whitespace-only context short-circuits to a fixed
            fallback message without running the model.

    Returns:
        The newly generated text only (the prompt is not echoed back),
        stripped of surrounding whitespace, or the fallback message when
        there is no usable context.
    """
    if not context or not context.strip():
        return "I couldn't find relevant information."

    max_new_tokens = 200

    # Built on one line so the source indentation of this function does not
    # leak into the prompt, and so the prompt ends right at the answer cue.
    prompt = f"Context: {context.strip()}\n\nQuestion: {query}\n\nAnswer:"

    encoded = tokenizer(prompt, return_tensors="pt")

    # Leave room for the generated tokens inside the model's context window.
    # Keeping the *tail* of the prompt preserves the question and the
    # "Answer:" cue; only the oldest part of the context is dropped.
    max_positions = getattr(gpt2_model.config, "max_position_embeddings", 1024)
    prompt_budget = max(1, max_positions - max_new_tokens)
    inputs = {
        name: tensor[:, -prompt_budget:].to(gpt2_model.device)
        for name, tensor in encoded.items()
    }
    prompt_length = inputs["input_ids"].shape[1]

    outputs = gpt2_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # temperature/top_p only take effect when sampling
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
    )

    # The output sequence starts with the prompt tokens; decode only what
    # the model appended after them.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

# Function to Query Documents from Qdrant
def query_documents(collection_name, user_query, top_k=5, score_threshold=0.5):
    """Search Qdrant for ``user_query`` and answer it with GPT-2.

    The query is preprocessed, embedded, and used for a vector search in
    ``collection_name``. Hits scoring below ``score_threshold`` or lacking a
    string ``"text"`` payload are discarded; the remaining texts are joined
    into the context passed to ``generate_answer``.

    Args:
        collection_name: Name of the Qdrant collection to search.
        user_query: Raw question from the user.
        top_k: Maximum number of hits requested from Qdrant.
        score_threshold: Minimum score a hit needs to be kept.

    Returns:
        ``{"answer": str, "chunks": [{"id", "score", "text"}, ...]}`` on
        success, or ``{"error": str}`` if anything raised. This function
        does not propagate exceptions.
    """
    try:
        logging.info(f"🔍 Original Query: {user_query}")
        processed_query = preprocess.preprocess_text(user_query)
        logging.info(f" Preprocessed Query: {processed_query}")

        # Generate Query Embedding
        query_vector = embedding_model.encode(processed_query).tolist()

        # Search in Qdrant
        search_results = qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=top_k,
            with_payload=True,
        )

        if not search_results:
            logging.warning(" No results found. Try increasing top_k or checking indexing.")

        # Filter Results
        filtered_results = []
        for res in search_results:
            payload = res.payload or {}  # a hit may carry no payload at all
            text = payload.get("text")
            if res.score >= score_threshold and isinstance(text, str):
                filtered_results.append(
                    {"id": res.id, "score": res.score, "text": text}
                )

        if search_results and not filtered_results:
            logging.warning(
                "All %d results were dropped (score below %s or no 'text' payload).",
                len(search_results),
                score_threshold,
            )

        # Extract Context for Answer Generation. An empty context is passed
        # through as-is so generate_answer returns its fallback message
        # instead of prompting the model with a placeholder sentence.
        context = " ".join(chunk["text"] for chunk in filtered_results)
        answer = generate_answer(user_query, context)

        return {"answer": answer, "chunks": filtered_results}

    except Exception as e:
        # Boundary handler: callers rely on the {"error": ...} shape.
        logging.exception(f"Error during query: {e}")
        return {"error": str(e)}

# Command-Line Execution
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, run one query, and print the
    # generated answer followed by the chunks it was based on.
    import argparse

    parser = argparse.ArgumentParser(description="Query documents with GPT-2")
    parser.add_argument("--collection", type=str, default="documents", help="Qdrant collection name")
    parser.add_argument("--query", type=str, required=True, help="Your search query")
    parser.add_argument("--top-k", type=int, default=3, help="Number of results to return")
    args = parser.parse_args()

    logging.info(f"Querying for: '{args.query}'")
    result = query_documents(args.collection, args.query, args.top_k)

    if "error" in result:
        logging.error(f" Error: {result['error']}")
        # Signal the failure to the calling shell instead of exiting with 0.
        raise SystemExit(1)

    logging.info("\n===  Generated Answer ===")
    print(result["answer"])

    logging.info("\n===  Relevant Chunks ===")
    for i, chunk in enumerate(result["chunks"], start=1):
        print(f"\nChunk {i} (Score: {chunk['score']:.3f}):\n{chunk['text']}")