import os
import asyncio
import configparser
import logging

from dotenv import load_dotenv

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Load environment variables from a local .env file
load_dotenv()
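
# Expected environment variables (taken from get_auth_config below; only the
# key for the configured provider is required), e.g. in the local .env file:
#   OPENAI_API_KEY=...
#   ANTHROPIC_API_KEY=...
#   COHERE_API_KEY=...
#   HF_TOKEN=...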


def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        # Use a context manager so the file handle is closed promptly
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        # Fail fast: silently returning None here would only surface later
        # as a confusing AttributeError in the module-level setup below
        logging.error("config file not found: %s", configfile_path)
        raise


# ---------------------------------------------------------------------
# Provider-agnostic authentication and configuration
# ---------------------------------------------------------------------
def get_auth_config(provider: str) -> dict:
    """Get the authentication configuration for the given provider."""
    auth_configs = {
        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
    }
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    if not auth_config.get("api_key"):
        raise RuntimeError(
            f"Missing API key for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
    return auth_config


# ---------------------------------------------------------------------
# Model / client initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = config.getint("generator", "MAX_TOKENS")
TEMPERATURE = config.getfloat("generator", "TEMPERATURE")
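
# A minimal params.cfg matching the reads above might look like this
# (the concrete values are illustrative assumptions, not the project's
# actual configuration):
#
#   [generator]
#   PROVIDER = openai
#   MODEL = gpt-4o-mini
#   MAX_TOKENS = 512
#   TEMPERATURE = 0.2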

# Set up authentication for the selected provider
auth_config = get_auth_config(PROVIDER)


def get_chat_model():
    """Initialize the appropriate LangChain chat model for the configured provider."""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params,
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params,
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params,
        )
    elif PROVIDER == "huggingface":
        # Hugging Face takes two steps: a raw endpoint client, then the chat
        # wrapper around it. Note the endpoint expects max_new_tokens rather
        # than max_tokens.
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")


# Initialize the provider-agnostic chat model once at import time
chat_model = get_chat_model()


# ---------------------------------------------------------------------
# Core generation function for both the Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as a string
    """
    try:
        # ainvoke keeps the event loop free while waiting on the provider
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception:
        # logging.exception already appends the traceback
        logging.exception(
            "LLM generation failed with provider '%s' and model '%s'",
            PROVIDER, MODEL,
        )
        raise


def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER question using only the "
        "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content),
    ]


async def rag_generate(query: str, context: str) -> str:
    """
    Generate an answer to a query using provided context through RAG.

    Takes a user query and relevant context, then uses the configured
    language model to generate an answer grounded in that context.

    Args:
        query (str): The user's question or query
        context (str): The relevant context/documents to use for answering

    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"
    if not context.strip():
        return "Error: Context cannot be empty"
    try:
        messages = build_messages(query, context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"