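"""RAG answer pipeline.

Rewrites user questions into search queries with a Groq-hosted Llama model,
retrieves matching context from Qdrant (see qdrant_setup), and answers with
Gemini through Google's OpenAI-compatible endpoint. Four API keys per
provider allow batches of requests to be fanned out in parallel.
"""
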
import concurrent.futures
from typing import List

from openai import OpenAI

from config import (
    gemini_api_key_1, gemini_api_key_2, gemini_api_key_3, gemini_api_key_4,
    query_gen_api_key_1, query_gen_api_key_2, query_gen_api_key_3, query_gen_api_key_4
)

# Answer-generation clients: Gemini served through Google's OpenAI-compatible
# endpoint, one client per API key so requests can be spread across keys.
client_1 = OpenAI(
    api_key=gemini_api_key_1,
    base_url="https://generativelanguage.googleapis.com/v1beta/"
)
client_2 = OpenAI(
    api_key=gemini_api_key_2,
    base_url="https://generativelanguage.googleapis.com/v1beta/"
)
client_3 = OpenAI(
    api_key=gemini_api_key_3,
    base_url="https://generativelanguage.googleapis.com/v1beta/"
)
client_4 = OpenAI(
    api_key=gemini_api_key_4,
    base_url="https://generativelanguage.googleapis.com/v1beta/"
)

# Query generation clients
query_client_1 = OpenAI(
    api_key=query_gen_api_key_1,
    base_url="https://api.groq.com/openai/v1"
)
query_client_2 = OpenAI(
    api_key=query_gen_api_key_2,
    base_url="https://api.groq.com/openai/v1"
)
query_client_3 = OpenAI(
    api_key=query_gen_api_key_3,
    base_url="https://api.groq.com/openai/v1"
)
query_client_4 = OpenAI(
    api_key=query_gen_api_key_4,
    base_url="https://api.groq.com/openai/v1"
)

def get_chat_completion(query: str, client: OpenAI) -> str:
    """Answer a query using only the context embedded in the prompt."""
    response = client.chat.completions.create(
        model="gemini-2.5-flash",
        n=1,
        messages=[
            {"role": "system", "content": "You are an assistant that answers questions only using the given context. Do not guess or add anything outside the context."},
            {"role": "user", "content": query}
        ]
    )
    return response.choices[0].message.content

def construct_prompts(questions: List[str], contexts: List[str]) -> List[str]:
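    """Pair each retrieved context with its question into a single prompt string."""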
    return [
        f"CONTEXT:\n{context}\n\nQUESTION:\n{question}"
        for context, question in zip(contexts, questions)
    ]

gemini_clients = [client_1, client_2, client_3, client_4]

def generate_answers(prompts: List[str], clients: List[OpenAI] = gemini_clients) -> List[str]:
    """Answer all prompts in parallel, rotating requests across the given clients."""
    answers = [None] * len(prompts)  # Results are written back by index to preserve order

    def gemini_task(idx_client_prompt):
        idx, client, prompt = idx_client_prompt
        ans = get_chat_completion(prompt, client)
        return idx, ans

    tasks = [
        (i, clients[i % len(clients)], prompts[i])
        for i in range(len(prompts))
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        # Ensure order by writing result at respective index
        for idx, ans in executor.map(gemini_task, tasks):
            answers[idx] = ans

    return answers
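
# Example usage (hypothetical data; assumes valid Gemini keys in config.py):
#   prompts = construct_prompts(["What is the grace period?"], ["<retrieved policy text>"])
#   answers = generate_answers(prompts)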

# Query generation functions
query_gen_clients = [query_client_1, query_client_2, query_client_3, query_client_4]

def get_query_generation(user_input: str, client) -> str:
    """Generate search queries based on user input using query generation clients."""
    response = client.chat.completions.create(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        n=1,
        messages=[
            {"role": "system", "content": "You are an expert at generating search queries. Given a user input, generate the most relevant search query that would help find the best information to answer the user's question. Return only the search query, nothing else."},
            {"role": "user", "content": f"Generate a search query for: {user_input}"}
        ]
    )
    return response.choices[0].message.content

def generate_search_queries(user_inputs: List[str], clients: List[OpenAI] = query_gen_clients) -> List[str]:
    """Generate search queries for multiple user inputs using parallel processing."""
    queries = [None] * len(user_inputs)

    def query_gen_task(idx_client_input):
        idx, client, user_input = idx_client_input
        query = get_query_generation(user_input, client)
        return idx, query

    tasks = [
        (i, clients[i % len(clients)], user_inputs[i])
        for i in range(len(user_inputs))
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        for idx, query in executor.map(query_gen_task, tasks):
            queries[idx] = query

    return queries
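
# Example usage (hypothetical input; assumes valid Groq keys in config.py):
#   queries = generate_search_queries(["Does this policy cover maternity expenses?"])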

def process_user_queries_with_context(user_inputs: List[str]) -> List[str]:
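    """End-to-end pipeline: rewrite user inputs into search queries, retrieve
    context from Qdrant, and generate context-grounded answers."""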
    from qdrant_setup import get_context_for_questions

    search_queries = generate_search_queries(user_inputs)
    contexts = get_context_for_questions(search_queries)
    prompts = construct_prompts(user_inputs, contexts)
    answers = generate_answers(prompts)
    return answers

if __name__ == "__main__":
    # Example user questions
    user_questions = [
        "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
        "What is the waiting period for pre-existing diseases (PED) to be covered?",
        "Does this policy cover maternity expenses, and what are the conditions?",
        "What is the waiting period for cataract surgery?"
    ]

    answers = process_user_queries_with_context(user_questions)
    for i, (question, answer) in enumerate(zip(user_questions, answers)):
        print(f"\nQ{i+1}: {question}")
        print(f"A{i+1}: {answer}")
#    print("\n=== Testing original approach ===")
#    # Original approach for comparison
#    from qdrant_setup import get_context_for_questions
#    context = get_context_for_questions(user_questions)
#    prompts = construct_prompts(user_questions, context)
#    original_answers = generate_answers(prompts)
   
#    print("\nComparison:")
#    for i, (orig, new) in enumerate(zip(original_answers, answers)):
#        print(f"\nQuestion {i+1}:")
#        print(f"Original: {orig[:100]}...")
#        print(f"New: {new[:100]}...")