# ChatAI / providers.py
# Uploaded by wynai ("Upload 6 files", commit 5eb5327).
import os, requests
from typing import List, Dict, Any, Optional
# --- Provider base ---
class ProviderError(Exception):
    """Raised when a chat provider cannot be set up or its backend call fails."""
class ChatProvider:
    """Common interface for chat back-ends.

    Concrete providers override :meth:`generate`; this base version only
    signals that no implementation is available.
    """

    def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
        """Return the assistant reply for ``messages``.

        Args:
            messages: conversation history as a list of message dicts.
            **kwargs: provider-specific options.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
# --- OpenAI ---
class OpenAIProvider(ChatProvider):
    """Chat provider backed by the official ``openai`` Python SDK (>= 1.0).

    The API key comes from the constructor argument, falling back to the
    ``OPENAI_API_KEY`` environment variable.
    """

    def __init__(self, api_key: Optional[str] = None) -> None:
        """Resolve the API key and build the SDK client.

        Args:
            api_key: OpenAI API key; when falsy, ``OPENAI_API_KEY`` is read
                from the environment instead.

        Raises:
            ProviderError: if no key is available, or if importing or
                constructing the ``openai`` client fails.
        """
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ProviderError("OpenAI API key missing. Provide it in the sidebar or set OPENAI_API_KEY.")
        try:
            # Use official library (>=1.0). Importing here rather than at module
            # top means a missing SDK surfaces as a ProviderError.
            from openai import OpenAI
            self.client = OpenAI(api_key=self.api_key)
        except Exception as e:
            # Chain the cause so the original traceback stays available.
            raise ProviderError(f"OpenAI library error: {e}") from e

    def generate(self, messages: List[Dict[str, Any]], model: str = "gpt-4o-mini", temperature: float = 0.3, **kwargs) -> str:
        """Send ``messages`` to the Chat Completions API and return the reply text.

        Args:
            messages: chat messages, passed unchanged to
                ``client.chat.completions.create``.
            model: model name (default ``"gpt-4o-mini"``).
            temperature: sampling temperature (default 0.3).
            **kwargs: accepted to match the ``ChatProvider.generate``
                signature; ignored here.

        Returns:
            Content of the first choice, or ``""`` when that content is
            empty or ``None``.

        Raises:
            ProviderError: wrapping any exception raised by the request or
                by reading the response.
        """
        try:
            resp = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            return resp.choices[0].message.content or ""
        except Exception as e:
            raise ProviderError(f"OpenAI API error: {e}") from e
# --- Ollama ---
class OllamaProvider(ChatProvider):
    """Chat provider that posts to an Ollama server's ``/api/chat`` endpoint."""

    def __init__(self, base_url: str = "http://localhost:11434", timeout: float = 120) -> None:
        """Store connection settings.

        Args:
            base_url: server root URL; trailing slashes are stripped so the
                endpoint path can be appended directly.
            timeout: seconds passed as ``timeout`` to ``requests.post``
                (default 120, the value previously hard-coded).
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Return the assistant text from a decoded chat response body.

        Reads ``data["message"]["content"]`` first. If that is empty or
        missing and the body has a ``choices`` key, it falls back to
        ``choices[0]["message"]["content"]``. Always returns a ``str``, and
        ``""`` when no content is found.
        """
        # `or {}` also covers an explicit JSON null for "message".
        message = data.get("message") or {}
        content = message.get("content")
        if not content and "choices" in data:
            content = data["choices"][0]["message"]["content"]
        # Normalise None to "" so the declared `-> str` holds, as in OpenAIProvider.
        return content or ""

    def generate(self, messages: List[Dict[str, Any]], model: str = "llama3.1:8b", temperature: float = 0.3, **kwargs) -> str:
        """Send ``messages`` to ``{base_url}/api/chat`` and return the reply text.

        Args:
            messages: chat messages, sent unchanged in the JSON payload.
            model: model name (default ``"llama3.1:8b"``).
            temperature: sent as ``options.temperature`` (default 0.3).
            **kwargs: accepted to match the ``ChatProvider.generate``
                signature; ignored here.

        Returns:
            The reply text, or ``""`` if the response carried no content.

        Raises:
            ProviderError: wrapping any HTTP, decoding or response-shape error.
        """
        url = f"{self.base_url}/api/chat"
        # Streaming is disabled; the code below parses a single JSON body.
        payload = {"model": model, "messages": messages, "stream": False, "options": {"temperature": temperature}}
        try:
            r = requests.post(url, json=payload, timeout=self.timeout)
            r.raise_for_status()
            return self._extract_content(r.json())
        except Exception as e:
            # Chain the cause so the original traceback stays available.
            raise ProviderError(f"Ollama API error: {e}") from e
# --- Helper ---
def convert_streamlit_messages_to_openai(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten chat-history dicts into plain ``{"role", "content"}`` messages.

    Each entry may carry ``role`` (default ``"user"``), ``content``
    (default ``""``) and an optional ``attachments`` list. Attachments are not
    forwarded as files. Instead, a text summary listing each attachment's
    ``name`` and ``type`` (each defaulting to ``"file"``) is appended to the
    content. All other keys are dropped.

    Args:
        messages: chat history entries (presumably Streamlit session-state
            messages, per the function name). Attachment entries must
            support ``.get``.

    Returns:
        A new list with one ``{"role": ..., "content": ...}`` dict per entry.
    """
    converted = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        attachments = m.get("attachments", [])
        if attachments:
            if content is None:
                # An explicit null content would otherwise raise TypeError on +=.
                content = ""
            # Real newline escapes, so the summary renders on separate lines.
            content += "\n\n[Attachments]\n" + "\n".join(
                f"- {a.get('name', 'file')} ({a.get('type', 'file')})" for a in attachments
            )
        converted.append({"role": role, "content": content})
    return converted