# mini-rag / llm.py
# LLM provider abstraction (chat, embeddings, rerank) for the mini-RAG app.
import os
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
# OpenAI SDK v1
from openai import OpenAI
# Groq
from groq import Groq
# Cohere
import cohere
load_dotenv()
class LLMProvider:
    """Thin facade over the OpenAI, Groq, and Cohere SDKs.

    Responsibilities:
      * chat completions via the provider named in ``LLM_PROVIDER``
        (``openai`` or ``groq``),
      * embeddings via OpenAI (``EMBEDDING_MODEL``),
      * optional reranking via Cohere (``RERANK_PROVIDER`` / ``RERANK_MODEL``).

    Clients are created eagerly in ``__init__`` but only when the matching
    ``*_API_KEY`` environment variable is present; a failed initialization is
    downgraded to a printed warning so the app can still run with whichever
    providers did come up.
    """

    def __init__(self) -> None:
        """Read provider/model configuration from the environment and build clients."""
        # Provider/model selection, with sensible defaults.
        self.provider = os.getenv("LLM_PROVIDER", "openai").lower()
        self.llm_model = os.getenv("LLM_MODEL", "gpt-4o-mini")
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
        self.rerank_provider = os.getenv("RERANK_PROVIDER", "cohere").lower()
        self.rerank_model = os.getenv("RERANK_MODEL", "rerank-3")
        # A client stays None when its key is absent or initialization fails.
        self._openai_client: Optional[OpenAI] = None
        self._groq_client: Optional[Groq] = None
        self._cohere_client: Optional[cohere.Client] = None
        # Factories are lambdas so the SDK constructor name is only resolved
        # when the corresponding key is actually set.
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            self._openai_client = self._try_client(
                "OpenAI", lambda: OpenAI(api_key=openai_key)
            )
        groq_key = os.getenv("GROQ_API_KEY")
        if groq_key:
            self._groq_client = self._try_client(
                "Groq", lambda: Groq(api_key=groq_key)
            )
        cohere_key = os.getenv("COHERE_API_KEY")
        if cohere_key:
            self._cohere_client = self._try_client(
                "Cohere", lambda: cohere.Client(api_key=cohere_key)
            )

    @staticmethod
    def _try_client(label: str, factory):
        """Build a client via ``factory``; on any error warn and return None."""
        try:
            return factory()
        except Exception as e:
            print(f"Warning: Failed to initialize {label} client: {e}")
            return None

    # Embeddings (via OpenAI by default)
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the configured OpenAI embedding model.

        Returns one vector per input text, in input order.
        Raises ValueError when no OpenAI client is available.
        """
        if not texts:
            # The embeddings endpoint rejects empty input; short-circuit.
            return []
        if not self._openai_client:
            raise ValueError("Embeddings require OPENAI_API_KEY set in environment")
        resp = self._openai_client.embeddings.create(model=self.embedding_model, input=texts)
        return [d.embedding for d in resp.data]

    # Chat completion via selected provider
    def chat(self, messages: List[Dict[str, str]], temperature: float = 0.2, max_tokens: int = 512) -> str:
        """Run a chat completion through the configured provider.

        ``messages`` follows the OpenAI chat format ({"role", "content"}).
        Returns the assistant message content (empty string if the API
        returned None). Raises ValueError when the provider's key is missing
        or ``LLM_PROVIDER`` names an unsupported backend.
        """
        if self.provider == "openai":
            if not self._openai_client:
                raise ValueError("OPENAI_API_KEY is missing")
            client = self._openai_client
        elif self.provider == "groq":
            if not self._groq_client:
                raise ValueError("GROQ_API_KEY is missing")
            client = self._groq_client
        else:
            raise ValueError(f"Unsupported LLM_PROVIDER: {self.provider}")
        # Groq's SDK mirrors the OpenAI chat-completions surface, so one
        # call path serves both providers.
        resp = client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.choices[0].message.content or ""

    def rerank(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank ``documents`` (dicts with a "text" key) against ``query``.

        When a Cohere client is configured, returns the documents reordered
        by relevance with a ``rerank_score`` added to each; otherwise (or for
        empty input) returns ``documents`` unchanged.
        """
        # documents: list of {text: str, metadata: dict, score: float}
        if self.rerank_provider == "cohere" and self._cohere_client and documents:
            inputs = [d["text"] for d in documents]
            result = self._cohere_client.rerank(
                model=self.rerank_model,
                query=query,
                documents=inputs,
                top_n=len(inputs),
            )
            # cohere SDK v5 wraps the ranked items in ``result.results``;
            # v4 responses iterated directly -- support both shapes.
            items = getattr(result, "results", result)
            ranked: List[Dict[str, Any]] = []
            for item in items:
                doc = documents[item.index]
                ranked.append({**doc, "rerank_score": float(item.relevance_score)})
            return ranked
        # Fallback: return original order
        return documents