Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| from config import ( | |
| gemini_api_key_1, gemini_api_key_2, gemini_api_key_3, gemini_api_key_4, | |
| query_gen_api_key_1, query_gen_api_key_2, query_gen_api_key_3, query_gen_api_key_4 | |
| ) | |
| from typing import List | |
# Base URLs for the two OpenAI-compatible endpoints used below.
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/"
GROQ_BASE_URL = "https://api.groq.com/openai/v1"


def _build_client(api_key: str, base_url: str) -> OpenAI:
    """Construct an OpenAI-compatible client for the given API key/endpoint."""
    return OpenAI(api_key=api_key, base_url=base_url)


# Answer-generation clients (Gemini, via its OpenAI-compatible endpoint).
# One client per API key so requests can be round-robined in parallel.
client_1, client_2, client_3, client_4 = (
    _build_client(key, GEMINI_BASE_URL)
    for key in (gemini_api_key_1, gemini_api_key_2, gemini_api_key_3, gemini_api_key_4)
)

# Query generation clients (Groq).
query_client_1, query_client_2, query_client_3, query_client_4 = (
    _build_client(key, GROQ_BASE_URL)
    for key in (query_gen_api_key_1, query_gen_api_key_2, query_gen_api_key_3, query_gen_api_key_4)
)
def get_chat_completion(query: str, client) -> str:
    """Answer *query* with a context-restricted chat completion.

    Args:
        query: Full prompt text (context + question) sent as the user message.
        client: An OpenAI-compatible client; its chat endpoint is called once.

    Returns:
        The text content of the first completion choice.
        (Bug fix: the original annotation said ``dict`` but the function
        has always returned the message-content string.)
    """
    response = client.chat.completions.create(
        model="gemini-2.5-flash",  # Or any other model you want to use
        n=1,
        messages=[
            {"role": "system", "content": "You are an assistant that answers questions only using the given context. Do not guess or add anything outside the context. "},
            {"role": "user", "content": query}
        ]
    )
    return response.choices[0].message.content
def construct_prompts(questions: List[str], contexts: List[str]) -> List[str]:
    """Merge each (context, question) pair into a single prompt string.

    Pairs are formed positionally; extra items in the longer list are dropped.
    """
    def _as_prompt(ctx: str, q: str) -> str:
        return f"CONTEXT:\n{ctx}\n\nQUESTION:\n{q}"

    return [_as_prompt(ctx, q) for ctx, q in zip(contexts, questions)]
import concurrent.futures
# Round-robin pool of Gemini clients used to parallelize answer generation.
gemini_clients = [client_1, client_2, client_3, client_4]
def generate_answers(prompts: List[str], clients: List = gemini_clients) -> List[str]:
    """Answer every prompt concurrently, round-robining across *clients*.

    Args:
        prompts: Fully-constructed prompts (context + question).
        clients: Chat clients to distribute requests over; defaults to the
            module-level Gemini pool.

    Returns:
        Answers in the same order as *prompts*.
    """
    if not prompts:
        return []  # also avoids spinning up an executor for nothing
    answers = [None] * len(prompts)  # filled by index to preserve input order

    def gemini_task(idx_client_prompt):
        idx, client, prompt = idx_client_prompt
        return idx, get_chat_completion(prompt, client)

    # Bug fix: the original ignored the `clients` argument and always used
    # the module-level gemini_clients pool.
    tasks = [
        (i, clients[i % len(clients)], prompt)
        for i, prompt in enumerate(prompts)
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        # Ensure order by writing each result at its respective index.
        for idx, ans in executor.map(gemini_task, tasks):
            answers[idx] = ans
    return answers
# Query generation functions
# Round-robin pool of Groq clients used to parallelize search-query generation.
query_gen_clients = [query_client_1, query_client_2, query_client_3, query_client_4]
def get_query_generation(user_input: str, client) -> str:
    """Generate search queries based on user input using query generation clients."""
    system_msg = {"role": "system", "content": "You are an expert at generating search queries. Given a user input, generate the most relevant search query that would help find the best information to answer the user's question. Return only the search query, nothing else."}
    user_msg = {"role": "user", "content": f"Generate a search query for: {user_input}"}
    completion = client.chat.completions.create(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        n=1,
        messages=[system_msg, user_msg],
    )
    return completion.choices[0].message.content
def generate_search_queries(user_inputs: List[str], clients: List = query_gen_clients) -> List[str]:
    """Generate search queries for multiple user inputs using parallel processing.

    Args:
        user_inputs: Raw user questions.
        clients: Query-generation clients to round-robin over; defaults to the
            module-level Groq pool.

    Returns:
        One generated search query per input, in the same order.
    """
    if not user_inputs:
        return []
    queries = [None] * len(user_inputs)  # filled by index to preserve input order

    def query_gen_task(idx_client_input):
        idx, client, user_input = idx_client_input
        return idx, get_query_generation(user_input, client)

    # Bug fix: the original ignored the `clients` argument and always used
    # the module-level query_gen_clients pool.
    tasks = [
        (i, clients[i % len(clients)], text)
        for i, text in enumerate(user_inputs)
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(clients)) as executor:
        for idx, query in executor.map(query_gen_task, tasks):
            queries[idx] = query
    return queries
def process_user_queries_with_context(user_inputs: List[str]) -> List[str]:
    """Full pipeline: user input -> search query -> retrieved context -> answer."""
    # Local import keeps qdrant a soft dependency at module-load time.
    from qdrant_setup import get_context_for_questions

    search_queries = generate_search_queries(user_inputs)
    retrieved_contexts = get_context_for_questions(search_queries)
    return generate_answers(construct_prompts(user_inputs, retrieved_contexts))
if __name__ == "__main__":
    # Example user questions exercising the full
    # query-gen -> retrieval -> answer pipeline.
    user_questions = [
        "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
        "What is the waiting period for pre-existing diseases (PED) to be covered?",
        "Does this policy cover maternity expenses, and what are the conditions?",
        "What is the waiting period for cataract surgery?"
    ]
    # Bug fix: the original discarded the pipeline's return value; print the
    # question/answer pairs so the demo actually shows its results.
    answers = process_user_queries_with_context(user_questions)
    for i, (question, answer) in enumerate(zip(user_questions, answers), start=1):
        print(f"\nQ{i}: {question}")
        print(f"A{i}: {answer}")