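"""OpenAI-compatible clients for a retrieval-augmented QA pipeline.

Gemini clients answer questions strictly from retrieved context, while
Groq-hosted Llama clients rewrite user questions into search queries.
Each provider gets four API keys, used round-robin to spread load across
per-key rate limits.
"""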
from openai import OpenAI
import concurrent.futures
from typing import List

from config import (
    gemini_api_key_1, gemini_api_key_2, gemini_api_key_3, gemini_api_key_4,
    query_gen_api_key_1, query_gen_api_key_2, query_gen_api_key_3, query_gen_api_key_4
)
# Gemini answer-generation clients (one per API key). Gemini's
# OpenAI-compatible endpoint is served under /v1beta/openai/.
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"

client_1 = OpenAI(api_key=gemini_api_key_1, base_url=GEMINI_BASE_URL)
client_2 = OpenAI(api_key=gemini_api_key_2, base_url=GEMINI_BASE_URL)
client_3 = OpenAI(api_key=gemini_api_key_3, base_url=GEMINI_BASE_URL)
client_4 = OpenAI(api_key=gemini_api_key_4, base_url=GEMINI_BASE_URL)

# Query generation clients (Groq-hosted Llama)
GROQ_BASE_URL = "https://api.groq.com/openai/v1"

query_client_1 = OpenAI(api_key=query_gen_api_key_1, base_url=GROQ_BASE_URL)
query_client_2 = OpenAI(api_key=query_gen_api_key_2, base_url=GROQ_BASE_URL)
query_client_3 = OpenAI(api_key=query_gen_api_key_3, base_url=GROQ_BASE_URL)
query_client_4 = OpenAI(api_key=query_gen_api_key_4, base_url=GROQ_BASE_URL)
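# Both pools are consumed round-robin (index i % pool size) by the thread
# pools below, so with four keys each one serves every fourth request.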
def get_chat_completion(query: str, client: OpenAI) -> str:
    """Answer a single prompt with the given Gemini client, using only the supplied context."""
    response = client.chat.completions.create(
        model="gemini-2.5-flash",  # Or any other model you want to use
        n=1,
        messages=[
            {"role": "system", "content": "You are an assistant that answers questions only using the given context. Do not guess or add anything outside the context."},
            {"role": "user", "content": query}
        ]
    )
    return response.choices[0].message.content
def construct_prompts(questions: List[str], contexts: List[str]) -> List[str]:
    """Pair each question with its retrieved context in a single prompt string."""
    return [
        f"CONTEXT:\n{context}\n\nQUESTION:\n{question}"
        for context, question in zip(contexts, questions)
    ]
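# For example (hypothetical inputs), construct_prompts(["What is the grace period?"],
# ["Clause 4.2: A grace period of thirty days..."]) returns
# ["CONTEXT:\nClause 4.2: A grace period of thirty days...\n\nQUESTION:\nWhat is the grace period?"]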
gemini_clients = [client_1, client_2, client_3, client_4]
def generate_answers(prompts: List[str], clients: List[OpenAI] = gemini_clients) -> List[str]:
    """Answer all prompts in parallel, rotating across the given clients."""
    answers = [None] * len(prompts)  # Filled by index so output order matches input order

    def gemini_task(idx_client_prompt):
        idx, client, prompt = idx_client_prompt
        ans = get_chat_completion(prompt, client)
        return idx, ans

    tasks = [
        (i, clients[i % len(clients)], prompts[i])
        for i in range(len(prompts))
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        # Ensure order by writing each result at its respective index
        for idx, ans in executor.map(gemini_task, tasks):
            answers[idx] = ans
    return answers
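# Note: executor.map yields results in submission order, so the answers would
# be aligned even without the index bookkeeping; writing by index simply makes
# the pairing explicit and robust to changes in the mapping strategy.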
# Query generation functions
query_gen_clients = [query_client_1, query_client_2, query_client_3, query_client_4]
def get_query_generation(user_input: str, client: OpenAI) -> str:
    """Generate a single search query for the user input using a query-generation client."""
    response = client.chat.completions.create(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        n=1,
        messages=[
            {"role": "system", "content": "You are an expert at generating search queries. Given a user input, generate the most relevant search query that would help find the best information to answer the user's question. Return only the search query, nothing else."},
            {"role": "user", "content": f"Generate a search query for: {user_input}"}
        ]
    )
    return response.choices[0].message.content
def generate_search_queries(user_inputs: List[str], clients: List[OpenAI] = query_gen_clients) -> List[str]:
    """Generate search queries for multiple user inputs using parallel processing."""
    queries = [None] * len(user_inputs)

    def query_gen_task(idx_client_input):
        idx, client, user_input = idx_client_input
        query = get_query_generation(user_input, client)
        return idx, query

    tasks = [
        (i, clients[i % len(clients)], user_inputs[i])
        for i in range(len(user_inputs))
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        for idx, query in executor.map(query_gen_task, tasks):
            queries[idx] = query
    return queries
def process_user_queries_with_context(user_inputs: List[str]) -> List[str]:
    """Full pipeline: rewrite questions into search queries, retrieve context, then answer."""
    from qdrant_setup import get_context_for_questions  # Imported here so qdrant_setup is only loaded when the pipeline runs

    search_queries = generate_search_queries(user_inputs)
    contexts = get_context_for_questions(search_queries)
    prompts = construct_prompts(user_inputs, contexts)
    answers = generate_answers(prompts)
    return answers
if __name__ == "__main__":
    # Example user questions
    user_questions = [
        "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
        "What is the waiting period for pre-existing diseases (PED) to be covered?",
        "Does this policy cover maternity expenses, and what are the conditions?",
        "What is the waiting period for cataract surgery?"
    ]

    # Run the complete pipeline (query generation -> retrieval -> answering)
    answers = process_user_queries_with_context(user_questions)
    for i, (question, answer) in enumerate(zip(user_questions, answers)):
        print(f"\nQ{i + 1}: {question}")
        print(f"A{i + 1}: {answer}")

    # # Original approach for comparison: retrieve with the raw questions,
    # # skipping the query-generation step.
    # from qdrant_setup import get_context_for_questions
    # contexts = get_context_for_questions(user_questions)
    # prompts = construct_prompts(user_questions, contexts)
    # original_answers = generate_answers(prompts)
    # print("\nComparison:")
    # for i, (orig, new) in enumerate(zip(original_answers, answers)):
    #     print(f"\nQuestion {i + 1}:")
    #     print(f"Original: {orig[:100]}...")
    #     print(f"New: {new[:100]}...")