import os
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv

# OpenAI SDK v1
from openai import OpenAI
# Groq
from groq import Groq
# Cohere
import cohere

load_dotenv()


class LLMProvider:
    """Facade over OpenAI / Groq / Cohere clients, configured from the environment.

    Environment variables:
        LLM_PROVIDER      chat backend: "openai" (default) or "groq"
        LLM_MODEL         chat model id (default "gpt-4o-mini")
        EMBEDDING_MODEL   OpenAI embedding model (default "text-embedding-3-small")
        RERANK_PROVIDER   rerank backend; only "cohere" is supported
        RERANK_MODEL      Cohere rerank model id
        OPENAI_API_KEY / GROQ_API_KEY / COHERE_API_KEY
                          credentials; a missing key leaves that client as None
    """

    def __init__(self) -> None:
        self.provider = os.getenv("LLM_PROVIDER", "openai").lower()
        self.llm_model = os.getenv("LLM_MODEL", "gpt-4o-mini")
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
        self.rerank_provider = os.getenv("RERANK_PROVIDER", "cohere").lower()
        # NOTE(review): "rerank-3" does not look like a published Cohere model id
        # (current ids include "rerank-english-v3.0", "rerank-v3.5"). Kept as-is
        # because RERANK_MODEL normally overrides it — confirm against deployment.
        self.rerank_model = os.getenv("RERANK_MODEL", "rerank-3")

        # Each client is initialized only when its API key is present; a failing
        # constructor is downgraded to a warning so the other providers stay usable.
        self._openai_client: Optional[OpenAI] = self._try_init(
            OpenAI, os.getenv("OPENAI_API_KEY"), "OpenAI"
        )
        self._groq_client: Optional[Groq] = self._try_init(
            Groq, os.getenv("GROQ_API_KEY"), "Groq"
        )
        self._cohere_client: Optional[cohere.Client] = self._try_init(
            cohere.Client, os.getenv("COHERE_API_KEY"), "Cohere"
        )

    @staticmethod
    def _try_init(factory: Callable[..., Any], key: Optional[str], label: str) -> Optional[Any]:
        """Build a client via ``factory(api_key=key)``; warn instead of raising on failure.

        Returns None when *key* is unset/empty or the constructor raises.
        """
        if not key:
            return None
        try:
            return factory(api_key=key)
        except Exception as e:
            print(f"Warning: Failed to initialize {label} client: {e}")
            return None

    # Embeddings (via OpenAI by default)
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with the configured OpenAI embedding model.

        Raises:
            ValueError: if no OpenAI client is available (OPENAI_API_KEY unset).
        """
        if not self._openai_client:
            raise ValueError("Embeddings require OPENAI_API_KEY set in environment")
        if not texts:
            # The embeddings endpoint rejects an empty input list; short-circuit.
            return []
        resp = self._openai_client.embeddings.create(model=self.embedding_model, input=texts)
        return [d.embedding for d in resp.data]

    # Chat completion via selected provider
    def chat(self, messages: List[Dict[str, str]], temperature: float = 0.2, max_tokens: int = 512) -> str:
        """Run a chat completion on the provider selected by LLM_PROVIDER.

        Args:
            messages: OpenAI-style list of {"role": ..., "content": ...} dicts.
            temperature: sampling temperature passed straight through.
            max_tokens: completion length cap passed straight through.

        Returns:
            The assistant message content, or "" when the API returns None.

        Raises:
            ValueError: if the selected provider's key is missing or the
                provider name is unsupported.
        """
        if self.provider == "openai":
            if not self._openai_client:
                raise ValueError("OPENAI_API_KEY is missing")
            resp = self._openai_client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content or ""
        elif self.provider == "groq":
            if not self._groq_client:
                raise ValueError("GROQ_API_KEY is missing")
            resp = self._groq_client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content or ""
        else:
            raise ValueError(f"Unsupported LLM_PROVIDER: {self.provider}")

    def rerank(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank *documents* against *query* with Cohere, best-effort.

        Args:
            query: the search query.
            documents: list of {"text": str, ...} dicts; "text" is what gets ranked.

        Returns:
            Documents ordered by relevance with a "rerank_score" key added, or
            the input unchanged when Cohere is unavailable or *documents* is
            empty (the rerank endpoint rejects an empty document list).
        """
        if self.rerank_provider == "cohere" and self._cohere_client and documents:
            inputs = [d["text"] for d in documents]
            result = self._cohere_client.rerank(
                model=self.rerank_model,
                query=query,
                documents=inputs,
                top_n=len(inputs),
            )
            # Cohere SDK v5 wraps the hits in ``result.results``; older SDKs
            # returned an iterable response directly — support both.
            hits = getattr(result, "results", result)
            ranked: List[Dict[str, Any]] = []
            for item in hits:  # ordered by relevance, each with .index / .relevance_score
                doc = documents[item.index]
                ranked.append({**doc, "rerank_score": float(item.relevance_score)})
            return ranked
        # Fallback: return original order
        return documents