import os
from typing import List, Dict, Any, Optional

from dotenv import load_dotenv

# OpenAI SDK v1
from openai import OpenAI
# Groq
from groq import Groq
# Cohere
import cohere

load_dotenv()
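
# Configuration comes from the environment (see the os.getenv calls in
# LLMProvider.__init__ below). A minimal .env sketch -- the variable names
# are the ones this module reads; the values are illustrative placeholders:
#
#   LLM_PROVIDER=openai              # or "groq"
#   LLM_MODEL=gpt-4o-mini
#   EMBEDDING_MODEL=text-embedding-3-small
#   RERANK_PROVIDER=cohere
#   RERANK_MODEL=rerank-english-v3.0
#   OPENAI_API_KEY=sk-...
#   GROQ_API_KEY=gsk_...
#   COHERE_API_KEY=...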


class LLMProvider:
    def __init__(self) -> None:
        self.provider = os.getenv("LLM_PROVIDER", "openai").lower()
        self.llm_model = os.getenv("LLM_MODEL", "gpt-4o-mini")
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
        self.rerank_provider = os.getenv("RERANK_PROVIDER", "cohere").lower()
        # "rerank-3" is not a published Cohere model name; default to the
        # v3.0 English rerank model instead.
        self.rerank_model = os.getenv("RERANK_MODEL", "rerank-english-v3.0")

        self._openai_client: Optional[OpenAI] = None
        self._groq_client: Optional[Groq] = None
        self._cohere_client: Optional[cohere.Client] = None
        # Initialize clients with explicit parameters
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            try:
                self._openai_client = OpenAI(api_key=openai_key)
            except Exception as e:
                print(f"Warning: Failed to initialize OpenAI client: {e}")
                self._openai_client = None

        groq_key = os.getenv("GROQ_API_KEY")
        if groq_key:
            try:
                self._groq_client = Groq(api_key=groq_key)
            except Exception as e:
                print(f"Warning: Failed to initialize Groq client: {e}")
                self._groq_client = None

        cohere_key = os.getenv("COHERE_API_KEY")
        if cohere_key:
            try:
                self._cohere_client = cohere.Client(api_key=cohere_key)
            except Exception as e:
                print(f"Warning: Failed to initialize Cohere client: {e}")
                self._cohere_client = None

    # Embeddings (via OpenAI by default)
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        if not self._openai_client:
            raise ValueError("Embeddings require OPENAI_API_KEY set in environment")
        resp = self._openai_client.embeddings.create(model=self.embedding_model, input=texts)
        return [d.embedding for d in resp.data]
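
    # Usage sketch (assumes OPENAI_API_KEY is set; vector width depends on
    # the embedding model -- text-embedding-3-small returns 1536-dim vectors):
    #
    #   provider = LLMProvider()
    #   vectors = provider.embed_texts(["first chunk", "second chunk"])
    #   assert len(vectors) == 2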

    # Chat completion via selected provider
    def chat(self, messages: List[Dict[str, str]], temperature: float = 0.2, max_tokens: int = 512) -> str:
        if self.provider == "openai":
            if not self._openai_client:
                raise ValueError("OPENAI_API_KEY is missing")
            resp = self._openai_client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content or ""
        elif self.provider == "groq":
            if not self._groq_client:
                raise ValueError("GROQ_API_KEY is missing")
            resp = self._groq_client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content or ""
        else:
            raise ValueError(f"Unsupported LLM_PROVIDER: {self.provider}")

    def rerank(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # documents: list of {text: str, metadata: dict, score: float}
        if self.rerank_provider == "cohere" and self._cohere_client:
            inputs = [d["text"] for d in documents]
            result = self._cohere_client.rerank(
                model=self.rerank_model,
                query=query,
                documents=inputs,
                top_n=len(inputs),
            )
            # In recent cohere SDKs the response itself is not iterable;
            # the hits, ordered by relevance, live in result.results.
            ranked: List[Dict[str, Any]] = []
            for item in result.results:
                idx = item.index
                doc = documents[idx]
                ranked.append({**doc, "rerank_score": float(item.relevance_score)})
            return ranked
        # Fallback: return original order
        return documents
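

if __name__ == "__main__":
    # Smoke test, a minimal sketch: exercises rerank() when COHERE_API_KEY is
    # present, and otherwise demonstrates the original-order fallback. The
    # sample documents are illustrative, not part of the real application.
    provider = LLMProvider()
    docs = [
        {"text": "Paris is the capital of France.", "metadata": {}, "score": 0.1},
        {"text": "The Eiffel Tower is in Paris.", "metadata": {}, "score": 0.2},
        {"text": "Python is a programming language.", "metadata": {}, "score": 0.3},
    ]
    for d in provider.rerank("Where is the Eiffel Tower?", docs):
        print(d.get("rerank_score", "no rerank score"), "-", d["text"])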