import logging
import asyncio
import ast
from typing import List, Dict, Any, Union

from dotenv import load_dotenv

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Local imports
from .utils import getconfig, get_auth

# Read API keys and other secrets from a local .env file, if present
load_dotenv()
# ---------------------------------------------------------------------
# Model / client initialization (non-exhaustive list of providers)
# ---------------------------------------------------------------------
config = getconfig("params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
ORGANIZATION = config.get("generator", "ORGANIZATION")

# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)
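# For reference, an illustrative params.cfg matching the keys read above
# (the values here are assumptions for this sketch, not the project's
# actual configuration):
#
#     [generator]
#     PROVIDER = openai
#     MODEL = gpt-4o-mini
#     MAX_TOKENS = 1024
#     TEMPERATURE = 0.2
#     INFERENCE_PROVIDER = hf-inference
#     ORGANIZATION = my-org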
def get_chat_model():
    """Initialize the appropriate LangChain chat model based on the configured provider."""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params,
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params,
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params,
        )
    elif PROVIDER == "huggingface":
        # HuggingFaceEndpoint takes max_new_tokens rather than max_tokens,
        # so it is configured explicitly instead of via common_params
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            provider=INFERENCE_PROVIDER,
            server_kwargs={"bill_to": ORGANIZATION},
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")

# Initialize the provider-agnostic chat model once at import time
chat_model = get_chat_model()
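# Illustrative sanity check (kept commented out so nothing runs at import
# time); assumes valid credentials for the configured PROVIDER, and the
# prompt text is a placeholder:
#
#     reply = chat_model.invoke([HumanMessage(content="Reply with 'pong'")])
#     print(reply.content)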
# ---------------------------------------------------------------------
# Context processing - may need further refinement (e.g. to handle other data sources)
# ---------------------------------------------------------------------
def extract_relevant_fields(retrieval_results: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Extract only the relevant fields from retrieval results.

    Args:
        retrieval_results: List of result objects from the retriever, or a
            string representation of such a list

    Returns:
        List of processed objects containing only the relevant fields
    """
    # The retriever may deliver results as a stringified Python literal;
    # parse it back into a list before processing
    if isinstance(retrieval_results, str):
        retrieval_results = ast.literal_eval(retrieval_results)

    processed_results = []
    for result in retrieval_results:
        # Extract the answer content
        answer = result.get('answer', '')

        # Extract document identification from metadata
        metadata = result.get('answer_metadata', {})
        doc_info = {
            'answer': answer,
            'filename': metadata.get('filename', 'Unknown'),
            'page': metadata.get('page', 'Unknown'),
            'year': metadata.get('year', 'Unknown'),
            'source': metadata.get('source', 'Unknown'),
            'document_id': metadata.get('_id', 'Unknown')
        }
        processed_results.append(doc_info)

    return processed_results
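# Illustrative call (the field values are made up; the keys mirror the
# metadata fields read above):
#
#     extract_relevant_fields([{
#         "answer": "Emissions fell 4% in 2020.",
#         "answer_metadata": {"filename": "report.pdf", "page": 12,
#                             "year": 2020, "source": "UNFCCC", "_id": "abc123"},
#     }])
#     # -> [{'answer': 'Emissions fell 4% in 2020.', 'filename': 'report.pdf',
#     #      'page': 12, 'year': 2020, 'source': 'UNFCCC', 'document_id': 'abc123'}]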
def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
    """
    Format processed retrieval results into a context string for the LLM.

    Args:
        processed_results: List of processed objects with relevant fields

    Returns:
        Formatted context string
    """
    if not processed_results:
        return ""

    context_parts = []
    for i, result in enumerate(processed_results, 1):
        # Build a citation-style header for each retrieved passage
        doc_reference = f"[Document {i}: {result['filename']}"
        if result['page'] != 'Unknown':
            doc_reference += f", Page {result['page']}"
        if result['year'] != 'Unknown':
            doc_reference += f", Year {result['year']}"
        doc_reference += "]"

        context_part = f"{doc_reference}\n{result['answer']}\n"
        context_parts.append(context_part)

    return "\n".join(context_parts)
# ---------------------------------------------------------------------
# Core generation function for both the Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as a string
    """
    try:
        # Use the async invoke API so concurrent requests don't block
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception:
        # logging.exception records the traceback, so the error itself
        # does not need to be interpolated into the message
        logging.exception(
            "LLM generation failed with provider '%s' and model '%s'", PROVIDER, MODEL
        )
        raise
def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER question using only the "
        "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content)
    ]
async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
    """
    Generate an answer to a query from the provided context (RAG).

    Takes a user query plus retrieved context and uses the configured
    language model to generate an answer grounded in that context.

    Args:
        query: The user's question
        context: Either a pre-formatted context string (from the Gradio UI)
            or a list of retrieval result objects (from the retriever)

    Returns:
        The generated answer, or an error message string
    """
    if not query.strip():
        return "Error: Query cannot be empty"

    # Handle both string context (from the Gradio UI) and list context (from the retriever)
    if isinstance(context, list):
        if not context:
            return "Error: No retrieval results provided"
        # Process the retrieval results
        processed_results = extract_relevant_fields(context)
        formatted_context = format_context_from_results(processed_results)
        if not formatted_context.strip():
            return "Error: No valid content found in retrieval results"
    elif isinstance(context, str):
        if not context.strip():
            return "Error: Context cannot be empty"
        formatted_context = context
    else:
        return "Error: Context must be either a string or a list of retrieval results"

    try:
        messages = build_messages(query, formatted_context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"