Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from operator import itemgetter | |
| from langchain_groq import ChatGroq | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.runnables import RunnableParallel, RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from langchain_core.documents import Document | |
| from query_expansion import expand_query_simple | |
| from typing import List, Optional | |
| import time | |
class GroqAPIKeyManager:
    """Hold a list of Groq API keys, rotate between them and count outcomes.

    Attributes:
        api_keys: Usable keys in rotation order (stripped, placeholder and
            duplicates removed).
        current_index: Index into ``api_keys`` of the key currently in use.
        failed_keys: Keys whose most recent recorded outcome was a failure.
        success_count / failure_count: Per-key counters.
    """

    def __init__(self, api_keys: List[str]):
        """Store the usable keys from *api_keys*.

        Empty entries, the template placeholder value and repeated keys are
        discarded; surrounding whitespace is stripped.

        Raises:
            ValueError: If no usable key remains.
        """
        usable: List[str] = []
        for raw_key in api_keys:
            if not raw_key:
                continue
            key = raw_key.strip()
            # A repeated key would otherwise be rotated to as if it were a
            # different credential while sharing one counter entry.
            if key and key != "your_groq_api_key_here" and key not in usable:
                usable.append(key)
        self.api_keys = usable
        if not self.api_keys:
            raise ValueError("No valid API keys provided!")
        self.current_index = 0
        self.failed_keys = set()
        self.success_count = {key: 0 for key in self.api_keys}
        self.failure_count = {key: 0 for key in self.api_keys}
        print(f"API Key Manager: Loaded {len(self.api_keys)} API keys")

    def get_current_key(self) -> str:
        """Return the key at ``current_index``."""
        return self.api_keys[self.current_index]

    def mark_success(self, api_key: str):
        """Record a successful call; clear the failed flag if it was set.

        Keys this manager does not own are ignored.
        """
        if api_key not in self.success_count:
            return
        self.success_count[api_key] += 1
        if api_key in self.failed_keys:
            self.failed_keys.remove(api_key)
            print(f"API Key #{self.api_keys.index(api_key) + 1} recovered!")

    def mark_failure(self, api_key: str):
        """Record a failed call and flag the key as failed.

        Keys this manager does not own are ignored, which keeps
        ``failed_keys`` a subset of ``api_keys``.
        """
        if api_key not in self.failure_count:
            return
        self.failure_count[api_key] += 1
        self.failed_keys.add(api_key)

    def rotate_to_next_key(self) -> bool:
        """Advance ``current_index`` to the next key worth trying.

        Keys are visited in order after the current one and the first key not
        flagged as failed is selected. If every other key is flagged, the
        walk completes a full cycle and lands back on the starting key, which
        is then retried regardless of its flag.

        Returns:
            True whenever a key was selected; False only if the key list is
            empty (which ``__init__`` prevents).
        """
        total = len(self.api_keys)
        for step in range(1, total + 1):
            self.current_index = (self.current_index + 1) % total
            if step == total:
                # Full cycle: this is the key we started from.
                print(f"All keys attempted, retrying with key #{self.current_index + 1}")
                return True
            if self.api_keys[self.current_index] not in self.failed_keys:
                print(f"Switching to API Key #{self.current_index + 1}")
                return True
        return False

    def get_statistics(self) -> str:
        """Return one line per key with its counters and ACTIVE/FAILED status.

        Keys longer than 12 characters are shown as their first 8 and last 4
        characters; shorter keys are fully masked as ``***``.
        """
        lines = []
        for number, key in enumerate(self.api_keys, start=1):
            status = "FAILED" if key in self.failed_keys else "ACTIVE"
            masked_key = key[:8] + "..." + key[-4:] if len(key) > 12 else "***"
            lines.append(
                f" Key #{number} ({masked_key}): {self.success_count[key]} success, "
                f"{self.failure_count[key]} failures [{status}]"
            )
        return "\n".join(lines)
def load_api_keys_from_hf_secrets() -> List[str]:
    """Collect Groq API keys from environment variables.

    Reads ``GROQ_API_KEY_1`` .. ``GROQ_API_KEY_4`` and then the un-numbered
    ``GROQ_API_KEY`` (read last so the numbered keys keep their positions).
    Values are stripped before validation, so a whitespace-padded placeholder
    or a blank value is rejected. A value identical to one already loaded is
    skipped.

    Returns:
        The usable keys in the order read; an empty list if none are set.
    """
    placeholder = "your_groq_api_key_here"
    secret_names = [
        'GROQ_API_KEY_1',
        'GROQ_API_KEY_2',
        'GROQ_API_KEY_3',
        'GROQ_API_KEY_4',
        'GROQ_API_KEY',
    ]
    api_keys: List[str] = []
    print("Loading API keys from Hugging Face Secrets...")
    for secret_name in secret_names:
        api_key = (os.getenv(secret_name) or "").strip()
        if not api_key or api_key == placeholder:
            print(f" Not found or empty: {secret_name}")
            continue
        if api_key in api_keys:
            print(f" Duplicate of an earlier key, skipped: {secret_name}")
            continue
        api_keys.append(api_key)
        print(f" Loaded: {secret_name}")
    return api_keys
def create_llm_with_fallback(
    api_key_manager: GroqAPIKeyManager,
    model_name: str,
    temperature: float,
    max_retries: int = 3
) -> ChatGroq:
    """Build a ``ChatGroq`` client, rotating API keys until one works.

    Each attempt constructs a client with the manager's current key and
    sends the one-off prompt ``"test"`` to verify the key. On failure the key
    is marked as failed, the manager is rotated and, after a one-second
    pause, the next attempt is made.

    Args:
        api_key_manager: Source of keys; its counters and current index are
            updated as a side effect.
        model_name: Model identifier passed to ``ChatGroq``.
        temperature: Sampling temperature passed to ``ChatGroq``.
        max_retries: Maximum number of attempts (not additional retries).
            NOTE(review): with the default of 3 and four configured keys the
            fourth key is never tried - confirm this is intended.

    Returns:
        The first client whose test prompt succeeded.

    Raises:
        ValueError: If every attempt fails, the manager cannot rotate, or
            ``max_retries`` is not positive.
    """
    # Substrings used only to pick a log message. The bare substring "api"
    # is deliberately not used: it appears in almost every client error.
    rate_markers = ("rate", "limit", "429")
    auth_markers = ("auth", "401", "invalid api key", "invalid_api_key")

    for attempt in range(max_retries):
        current_key = api_key_manager.get_current_key()
        key_number = api_key_manager.current_index + 1
        try:
            llm = ChatGroq(
                model_name=model_name,
                api_key=current_key,
                temperature=temperature
            )
            llm.invoke("test")
        except Exception as e:
            error_msg = str(e).lower()
            api_key_manager.mark_failure(current_key)
            if any(marker in error_msg for marker in rate_markers):
                print(f" Rate limit hit on API Key #{key_number}")
            elif any(marker in error_msg for marker in auth_markers):
                print(f" Authentication failed on API Key #{key_number}")
            else:
                print(f" Error with API Key #{key_number}: {str(e)[:50]}")

            if attempt >= max_retries - 1:
                raise ValueError(
                    f"Failed to initialize LLM after {max_retries} attempts"
                ) from e
            if not api_key_manager.rotate_to_next_key():
                raise ValueError("All API keys failed!") from e
            print(f" Retrying with next API key (Attempt {attempt + 2}/{max_retries})...")
            time.sleep(1)
        else:
            # Kept outside the try so a bookkeeping error is not mistaken
            # for a key failure.
            api_key_manager.mark_success(current_key)
            return llm

    # Only reachable when max_retries <= 0.
    raise ValueError("Failed to create LLM with any available API key")
def create_multi_query_retriever(base_retriever, llm, strategy: str = "balanced"):
    """Wrap *base_retriever* so each query is expanded into several variants.

    Args:
        base_retriever: Object exposing ``invoke(query)`` that returns
            documents with a ``page_content`` attribute.
        llm: Language model forwarded to ``expand_query_simple``.
        strategy: Expansion strategy name forwarded to
            ``expand_query_simple``.

    Returns:
        A function ``multi_query_retrieve(query)`` that returns the documents
        found for all variants, de-duplicated by page content in
        first-seen order.
    """
    def multi_query_retrieve(query: str) -> List[Document]:
        """Retrieve for every expansion of *query* and merge the results."""
        try:
            query_variations = expand_query_simple(query, strategy=strategy, llm=llm)
        except Exception as e:
            # Expansion is an enhancement; degrade to the plain query rather
            # than failing the whole retrieval.
            print(f" Query Expansion Error (expansion): {str(e)[:50]}")
            query_variations = None
        if not query_variations:
            query_variations = [query]

        all_docs = []
        # Keyed on the text itself rather than hash(text): two different
        # texts can share a hash, which would silently drop a document.
        seen_content = set()
        for i, query_var in enumerate(query_variations):
            try:
                docs = base_retriever.invoke(query_var)
            except Exception as e:
                # Best effort: one failing variant must not discard the rest.
                print(f" Query Expansion Error (Query {i+1}): {str(e)[:50]}")
                continue
            for doc in docs:
                content = doc.page_content
                if content not in seen_content:
                    seen_content.add(content)
                    all_docs.append(doc)
        print(f" Query Expansion: Retrieved {len(all_docs)} unique documents.")
        return all_docs

    return multi_query_retrieve
def get_system_prompt(temperature: float) -> str:
    """Pick the system prompt template that matches *temperature*.

    Three tiers are defined: up to 0.4 (structured and factual), above 0.4 up
    to 0.8 (balanced and engaging), and anything higher (creative but
    clarity-first). Every template contains a single ``{context}``
    placeholder for the retrieved documents.
    """
    structured_prompt = """You are CogniChat, an expert document analysis assistant specializing in comprehensive and well-structured answers.
RESPONSE GUIDELINES:
**Structure & Formatting:**
- Start with a direct answer to the question
- Use **bold** for key terms, important concepts, and technical terminology
- Use bullet points (•) for lists, features, or multiple items
- Use numbered lists (1., 2., 3.) for steps, procedures, or sequential information
- Use ### Headers to organize different sections or topics
- Add blank lines between sections for readability
**Source Citation:**
- Always cite information using: [Source: filename, Page: X] and cite it at the end of the entire answer only
- Place citations at the end of your final answer only
- Do not cite sources within the body of your answer
- Multiple sources: [Source: doc1.pdf, Page: 3; doc2.pdf, Page: 7]
**Completeness:**
- Provide thorough, detailed answers using ALL relevant information from context
- Summarize and properly elaborate each point for increased clarity
- If the question has multiple parts, address each part clearly
**Accuracy:**
- ONLY use information from the provided context documents below
- If information is incomplete, state what IS available and what ISN'T
- If the answer isn't in the context, clearly state: "I cannot find this information in the uploaded documents"
- Never make assumptions or add information not in the context
---
{context}
---
Now answer the following question comprehensively using the context above:"""

    balanced_prompt = """You are CogniChat, an intelligent document analysis assistant that combines accuracy with engaging communication.
RESPONSE GUIDELINES:
**Communication Style:**
- Present information in a clear, engaging manner
- Use **bold** for emphasis on important concepts
- Balance structure with natural flow
- Make complex topics accessible and interesting
**Content Approach:**
- Ground your response firmly in the provided context
- Add helpful explanations and connections between concepts
- Use analogies or examples when they help clarify ideas (but keep them brief)
- Organize information logically with headers (###) and lists where appropriate
**Source Attribution:**
- Cite sources at the end: [Source: filename, Page: X]
- Be transparent about what the documents do and don't contain
**Accuracy:**
- Base your answer on the context documents provided
- If information is partial, explain what's available
- Acknowledge gaps: "The documents don't cover this aspect"
---
{context}
---
Now answer the following question in an engaging yet accurate way:"""

    # Used for every temperature above 0.8.
    creative_prompt = """You are CogniChat, a creative document analyst who makes complex information clear, memorable, and engaging.
YOUR CORE MISSION: **CLARITY FIRST, CREATIVITY SECOND**
Make information easier to understand, not harder. Your creativity should illuminate, not obscure.
**CREATIVE CLARITY PRINCIPLES:**
1. **Simplify, Don't Complicate**
- Break down complex concepts into simple, digestible parts
- Use everyday language alongside technical terms
- Explain jargon immediately in plain English
- Short sentences for complex ideas, varied length for rhythm
2. **Smart Use of Examples & Analogies** (Use Sparingly!)
- Only use analogies when they genuinely make something clearer
- Keep analogies simple and relatable (everyday objects/experiences)
- Never use metaphors that require explanation themselves
- If you can explain it directly in simple terms, do that instead
3. **Engaging Structure**
- Start with the core answer in one clear sentence
- Use **bold** to highlight key takeaways
- Break information into logical chunks with ### headers
- Use bullet points for clarity, not decoration
- Add brief transition phrases to connect ideas smoothly
4. **Conversational Yet Precise**
- Write like you're explaining to a smart friend
- Use "you" and active voice to engage readers
- Ask occasional rhetorical questions only if they aid understanding
- Vary sentence length to maintain interest
- Use emojis sparingly (1-2 max) and only where they add clarity
5. **Visual Clarity**
- Strategic use of formatting: **bold** for key terms, *italics* for emphasis
- White space between sections for easy scanning
- Progressive disclosure: simple concepts first, details after
- Numbered lists for sequences, bullets for related items
**WHAT TO AVOID:**
- Flowery or overly descriptive language
- Complex metaphors that need their own explanation
- Long narrative storytelling that buries the facts
- Multiple rhetorical questions in a row
- Overuse of emojis or exclamation points
- Making simple things sound complicated
**ACCURACY BOUNDARIES:**
- Creative explanation and presentation of facts
- Simple, helpful examples from common knowledge
- Reorganizing information for better understanding
- Never invent facts not in the documents
- Don't contradict source material
- If info is missing, say so clearly and briefly
**Source Attribution:**
- End with: [Source: filename, Page: X]
- Keep it simple and clear
---
{context}
---
Now, explain the answer clearly and engagingly. Remember: if your grandmother couldn't understand it, simplify more:"""

    if temperature <= 0.4:
        return structured_prompt
    if temperature <= 0.8:
        return balanced_prompt
    return creative_prompt
def create_rag_chain(
    retriever,
    get_session_history_func,
    enable_query_expansion=True,
    expansion_strategy="balanced",
    model_name: str = "moonshotai/kimi-k2-instruct",
    temperature: float = 0.2,
    api_keys: Optional[List[str]] = None
):
    """Assemble the conversational RAG chain with per-session memory.

    Pipeline: the incoming question and chat history are rewritten into a
    standalone query, the query is sent to the retriever (optionally through
    multi-query expansion), the retrieved documents are formatted into a
    context block, and the answer prompt selected by *temperature* is run
    through the LLM.

    Args:
        retriever: Base retriever used to fetch documents.
        get_session_history_func: Callable handed to
            ``RunnableWithMessageHistory`` to obtain a session's history.
        enable_query_expansion: Wrap the retriever with
            ``create_multi_query_retriever`` when true.
        expansion_strategy: Strategy name forwarded to the expansion wrapper.
        model_name: Groq model identifier.
        temperature: Sampling temperature; also selects the system prompt.
        api_keys: Keys to use. When ``None`` they are loaded with
            ``load_api_keys_from_hf_secrets``.

    Returns:
        A tuple ``(chain_with_memory, api_key_manager)``.

    Raises:
        ValueError: If no API key is available or the LLM cannot be
            initialised with any key.
    """
    if api_keys is None:
        api_keys = load_api_keys_from_hf_secrets()
    if not api_keys:
        # NOTE(review): the message points at a .env file, while the default
        # loader reads environment variables - confirm the wording.
        raise ValueError(
            "No valid API keys found! Please set GROQ_API_KEY or GROQ_API_KEY_1, "
            "GROQ_API_KEY_2, GROQ_API_KEY_3, GROQ_API_KEY_4 in your .env file"
        )
    api_key_manager = GroqAPIKeyManager(api_keys)

    print(f" RAG: Initializing LLM - Model: {model_name}, Temp: {temperature}")
    # Same thresholds as get_system_prompt; used for logging only.
    if temperature <= 0.4:
        creativity_mode = "FACTUAL & STRUCTURED"
    elif temperature <= 0.8:
        creativity_mode = "BALANCED & ENGAGING"
    else:
        creativity_mode = "CREATIVE & STORYTELLING"
    print(f"Creativity Mode: {creativity_mode}")

    llm = create_llm_with_fallback(api_key_manager, model_name, temperature)
    print(f"LLM initialized with API Key #{api_key_manager.current_index + 1}")

    if enable_query_expansion:
        print(f"RAG: Query Expansion ENABLED (Strategy: {expansion_strategy})")
        enhanced_retriever = create_multi_query_retriever(
            base_retriever=retriever,
            llm=llm,
            strategy=expansion_strategy
        )
    else:
        enhanced_retriever = retriever

    # NOTE(review): the system text below interpolates {chat_history} and
    # {question}, and the same values are also supplied through the
    # MessagesPlaceholder and the human message, so the rewriter receives
    # both twice. Left unchanged because altering the prompt changes model
    # behaviour.
    rewrite_template = """You are an expert at optimizing search queries for document retrieval.
Given the conversation history and a follow-up question, create a comprehensive standalone question that:
1. Incorporates all relevant context from the chat history
2. Expands abbreviations and resolves all pronouns (it, they, this, that, etc.)
3. Includes key technical terms and concepts that would help find relevant documents
4. Maintains the original intent, specificity, and detail level
5. If the question asks for comparison or multiple items, ensure all items are in the query
Chat History:
{chat_history}
Follow-up Question: {question}
Optimized Standalone Question:"""
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}")
    ])
    query_rewriter = rewrite_prompt | llm | StrOutputParser()

    def format_docs(docs):
        """Render retrieved documents as one labelled context string."""
        if not docs:
            return "No relevant documents found in the knowledge base."
        formatted_parts = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get('source', 'Unknown Document')
            page = doc.metadata.get('page', 'N/A')
            rerank_score = doc.metadata.get('rerank_score')
            content = doc.page_content.strip()
            doc_header = f"{'='*60}\nDOCUMENT {i}\n{'='*60}"
            metadata_line = f"Source: {source} | Page: {page}"
            # Compare against None: a score of 0.0 is a real value and must
            # still be shown.
            if rerank_score is not None:
                metadata_line += f" | Relevance: {rerank_score:.3f}"
            formatted_parts.append(
                f"{doc_header}\n"
                f"{metadata_line}\n"
                f"{'-'*60}\n"
                f"{content}\n"
            )
        return f"RETRIEVED CONTEXT ({len(docs)} documents):\n\n" + "\n".join(formatted_parts)

    rag_template = get_system_prompt(temperature)
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # Narrow the chain input to the two keys the rewriter prompt needs.
    rewriter_input = RunnableParallel({
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    })
    retrieval_chain = rewriter_input | query_rewriter | enhanced_retriever | format_docs

    # The answer prompt gets the formatted context plus the ORIGINAL question
    # and history; the rewritten query is used for retrieval only.
    conversational_rag_chain = RunnableParallel({
        "context": retrieval_chain,
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    }) | rag_prompt | llm | StrOutputParser()

    chain_with_memory = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )
    print("RAG: Chain created successfully.")
    print("\n" + api_key_manager.get_statistics())
    return chain_with_memory, api_key_manager