|
|
import os, requests |
|
|
from typing import List, Dict, Any, Optional |
|
|
|
|
|
|
|
|
class ProviderError(Exception):
    """Error raised by chat providers when setup or a generation request fails."""
|
|
|
|
|
class ChatProvider:
    """Common interface for chat backends.

    Concrete providers override :meth:`generate`; the base implementation
    only signals that it has not been implemented.
    """

    def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
        """Produce a reply for *messages*.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
class OpenAIProvider(ChatProvider):
    """Chat provider backed by the ``openai`` Python client.

    The API key comes from the constructor argument or, if that is falsy,
    from the ``OPENAI_API_KEY`` environment variable.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Resolve the API key and construct the OpenAI client.

        Args:
            api_key: Explicit API key; falls back to ``OPENAI_API_KEY``.

        Raises:
            ProviderError: If no key is available, or if importing ``openai``
                or constructing the client fails.
        """
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ProviderError(
                "OpenAI API key missing. Provide it in the sidebar or set OPENAI_API_KEY."
            )
        try:
            # Imported lazily: the ``openai`` package is only needed once this
            # provider is actually constructed.
            from openai import OpenAI

            self.client = OpenAI(api_key=self.api_key)
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ProviderError(f"OpenAI library error: {e}") from e

    def generate(
        self,
        messages: List[Dict[str, Any]],
        model: str = "gpt-4o-mini",
        temperature: float = 0.3,
        **kwargs,
    ) -> str:
        """Send *messages* to the Chat Completions API and return the reply text.

        Args:
            messages: Passed through unchanged as the request's ``messages``
                (presumably ``{"role": ..., "content": ...}`` dicts).
            model: Model name to request.
            temperature: Sampling temperature.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            Content of the first choice's message, or ``""`` if it is empty/None.

        Raises:
            ProviderError: Wrapping any error from the request or from reading
                the response.
        """
        try:
            resp = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            return resp.choices[0].message.content or ""
        except Exception as e:
            raise ProviderError(f"OpenAI API error: {e}") from e
|
|
|
|
|
|
|
|
class OllamaProvider(ChatProvider):
    """Chat provider that POSTs to an Ollama server's ``/api/chat`` endpoint."""

    def __init__(self, base_url: str = "http://localhost:11434", timeout: float = 120):
        """Store the server location and request timeout.

        Args:
            base_url: Server root URL; a trailing ``/`` is stripped.
            timeout: Seconds passed to ``requests.post`` as its timeout.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def generate(
        self,
        messages: List[Dict[str, Any]],
        model: str = "llama3.1:8b",
        temperature: float = 0.3,
        **kwargs,
    ) -> str:
        """Request a non-streaming chat completion and return the reply text.

        Args:
            messages: Passed through unchanged as the request's ``messages``.
            model: Model tag to request.
            temperature: Sent as ``options.temperature``.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            The reply content, or ``""`` if the response carries none.

        Raises:
            ProviderError: Wrapping HTTP, timeout, JSON-decoding or
                response-shape errors.
        """
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": messages,
            "stream": False,
            "options": {"temperature": temperature},
        }
        try:
            r = requests.post(url, json=payload, timeout=self.timeout)
            r.raise_for_status()
            data = r.json()
            # ``or {}`` / ``or ""`` guard against explicit nulls in the response,
            # which previously raised AttributeError or leaked None to callers.
            msg = (data.get("message") or {}).get("content") or ""
            if not msg and "choices" in data:
                # Fallback for a response shaped as choices[0].message.content.
                msg = data["choices"][0]["message"]["content"] or ""
        except Exception as e:
            raise ProviderError(f"Ollama API error: {e}") from e
        return msg
|
|
|
|
|
|
|
|
def convert_streamlit_messages_to_openai(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten chat-history entries into ``{"role", "content"}`` dicts.

    Any ``attachments`` on a message are not sent as files; instead a textual
    ``[Attachments]`` list of ``- name (type)`` lines is appended to its content.

    Args:
        messages: Dicts that may contain ``role`` (default ``"user"``),
            ``content`` and ``attachments`` (a list of dicts with optional
            ``name``/``type`` keys).

    Returns:
        A new list with one ``{"role": ..., "content": ...}`` dict per message.
    """
    converted = []
    for m in messages:
        role = m.get("role", "user")
        # ``or`` also covers keys that are present but set to None.
        content = m.get("content") or ""
        attachments = m.get("attachments") or []
        if attachments:
            # Real newlines: the previous "\\n" literals inserted a literal
            # backslash-n into the text sent to the model.
            listing = "\n".join(
                f"- {a.get('name', 'file')} ({a.get('type', 'file')})" for a in attachments
            )
            content += "\n\n[Attachments]\n" + listing
        converted.append({"role": role, "content": content})
    return converted
|
|
|