import concurrent.futures
from typing import List

from openai import OpenAI

from config import (
    gemini_api_key_1, gemini_api_key_2, gemini_api_key_3, gemini_api_key_4,
    query_gen_api_key_1, query_gen_api_key_2, query_gen_api_key_3, query_gen_api_key_4,
)
# Gemini's OpenAI-compatible endpoint requires the trailing "openai/" path
# segment; the SDK appends "chat/completions" to whatever base_url it is given.
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
GROQ_BASE_URL = "https://api.groq.com/openai/v1"

# One client per API key so requests can be rotated round-robin across keys
# and stay within per-key rate limits.
gemini_clients = [
    OpenAI(api_key=key, base_url=GEMINI_BASE_URL)
    for key in (gemini_api_key_1, gemini_api_key_2, gemini_api_key_3, gemini_api_key_4)
]

# Query generation clients (Groq's OpenAI-compatible endpoint)
query_gen_clients = [
    OpenAI(api_key=key, base_url=GROQ_BASE_URL)
    for key in (query_gen_api_key_1, query_gen_api_key_2, query_gen_api_key_3, query_gen_api_key_4)
]
def get_chat_completion(query: str, client) -> str:
    """Answer a single prompt with the given client, grounded in the supplied context."""
    response = client.chat.completions.create(
        model="gemini-2.5-flash",  # or any other model you want to use
        n=1,
        messages=[
            {"role": "system", "content": "You are an assistant that answers questions only using the given context. Do not guess or add anything outside the context."},
            {"role": "user", "content": query}
        ]
    )
    return response.choices[0].message.content
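
# A minimal retry sketch (an illustrative assumption, not wired into the
# pipeline below): rate-limited endpoints can raise transient errors, so
# callers may want simple exponential backoff around get_chat_completion.
# The retry count and delays here are placeholders, not tuned values.
import time

def get_chat_completion_with_retry(query: str, client, retries: int = 3) -> str:
    """Hypothetical helper: retry get_chat_completion with exponential backoff."""
    for attempt in range(retries):
        try:
            return get_chat_completion(query, client)
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(2 ** attempt)  # back off 1s, 2s, 4s, ...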
def construct_prompts(questions: List[str], contexts: List[str]) -> List[str]:
return [
f"CONTEXT:\n{context}\n\nQUESTION:\n{question}"
for context, question in zip(contexts, questions)
]
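
# For example, construct_prompts(["What is covered?"], ["Policy text ..."])
# returns ["CONTEXT:\nPolicy text ...\n\nQUESTION:\nWhat is covered?"]; the
# answering model sees the retrieved context first, then the question.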
def generate_answers(prompts: List[str], clients: List = gemini_clients) -> List[str]:
    """Answer each prompt in parallel, rotating requests across the given clients."""
    answers = [None] * len(prompts)  # filled by index so output order matches input

    def gemini_task(idx_client_prompt):
        idx, client, prompt = idx_client_prompt
        return idx, get_chat_completion(prompt, client)

    tasks = [
        (i, clients[i % len(clients)], prompts[i])
        for i in range(len(prompts))
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        # Write each result at its original index so answers stay aligned with prompts.
        for idx, ans in executor.map(gemini_task, tasks):
            answers[idx] = ans
    return answers
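
# Note: executor.map already yields results in submission order, so the index
# bookkeeping above is redundant with map(); it does, however, make the function
# safe to switch to concurrent.futures.as_completed() later without reordering.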
# Query generation functions
def get_query_generation(user_input: str, client) -> str:
"""Generate search queries based on user input using query generation clients."""
response = client.chat.completions.create(
model="meta-llama/llama-4-scout-17b-16e-instruct",
n=1,
messages=[
{"role": "system", "content": "You are an expert at generating search queries. Given a user input, generate the most relevant search query that would help find the best information to answer the user's question. Return only the search query, nothing else."},
{"role": "user", "content": f"Generate a search query for: {user_input}"}
]
)
return response.choices[0].message.content
def generate_search_queries(user_inputs: List[str], clients: List = query_gen_clients) -> List[str]:
    """Generate search queries for multiple user inputs using parallel processing."""
    queries = [None] * len(user_inputs)  # filled by index so output order matches input

    def query_gen_task(idx_client_input):
        idx, client, user_input = idx_client_input
        return idx, get_query_generation(user_input, client)

    tasks = [
        (i, clients[i % len(clients)], user_inputs[i])
        for i in range(len(user_inputs))
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        for idx, query in executor.map(query_gen_task, tasks):
            queries[idx] = query
    return queries
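
# Example (the output shown is illustrative; the actual query depends on the model):
#   generate_search_queries(["Does this policy cover maternity expenses?"])
#   -> ["maternity expenses coverage conditions health insurance policy"]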
def process_user_queries_with_context(user_inputs: List[str]) -> List[str]:
    """Full pipeline: rewrite questions as search queries, retrieve context, answer."""
    # Local import so qdrant_setup is only loaded when the pipeline actually runs.
    from qdrant_setup import get_context_for_questions

    search_queries = generate_search_queries(user_inputs)
    contexts = get_context_for_questions(search_queries)
    prompts = construct_prompts(user_inputs, contexts)
    return generate_answers(prompts)
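
# Note the asymmetry above: the generated search queries drive retrieval, but
# each prompt pairs the retrieved context with the original user question, so
# the answering model responds to the user's exact wording.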
if __name__ == "__main__":
# Example user questions
user_questions = [
"What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
"What is the waiting period for pre-existing diseases (PED) to be covered?",
"Does this policy cover maternity expenses, and what are the conditions?",
"What is the waiting period for cataract surgery?"
]
    # Run the complete pipeline with query generation and print the answers.
    answers = process_user_queries_with_context(user_questions)
    for i, (question, answer) in enumerate(zip(user_questions, answers)):
        print(f"\nQ{i+1}: {question}")
        print(f"A{i+1}: {answer}")
# print("\n=== Testing original approach ===")
# # Original approach for comparison
# from qdrant_setup import get_context_for_questions
# context = get_context_for_questions(user_questions)
# prompts = construct_prompts(user_questions, context)
# original_answers = generate_answers(prompts)
# print("\nComparison:")
# for i, (orig, new) in enumerate(zip(original_answers, answers)):
# print(f"\nQuestion {i+1}:")
# print(f"Original: {orig[:100]}...")
# print(f"New: {new[:100]}...")