import os
from typing import Any, Dict, List, Optional

import requests

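# --- Provider base ---
class ProviderError(Exception):
    """Raised when a provider cannot be configured or its API call fails."""


class ChatProvider:
    """Common interface: take a list of chat messages and return the reply text."""

    def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
        raise NotImplementedError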

# --- OpenAI ---
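class OpenAIProvider(ChatProvider):
    def __init__(self, api_key: Optional[str] = None):
        # An explicitly passed key (e.g. from the sidebar) takes precedence over the environment.
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ProviderError("OpenAI API key missing. Provide it in the sidebar or set OPENAI_API_KEY.")

        try:
            # Imported here rather than at module level, so this module can be imported
            # without the openai package installed. Requires openai>=1.0 (the `OpenAI` client).
            from openai import OpenAI
            self.client = OpenAI(api_key=self.api_key)
        except Exception as e:
            raise ProviderError(f"OpenAI library error: {e}") from e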

    def generate(self, messages: List[Dict[str, Any]], model: str = "gpt-4o-mini", temperature: float = 0.3, **kwargs) -> str:
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        try:
            resp = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            # `content` can be None (e.g. for tool-call responses), so normalise it to "".
            return resp.choices[0].message.content or ""
        except Exception as e:
            raise ProviderError(f"OpenAI API error: {e}") from e
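
# Usage sketch (assumes OPENAI_API_KEY is set and the key has access to gpt-4o-mini;
# the prompt is illustrative only):
#
#     provider = OpenAIProvider()
#     reply = provider.generate(
#         [{"role": "user", "content": "Say hello in one word."}],
#         temperature=0.0,
#     )
#
# Failures at either stage (missing key, import problem, API error) surface as
# ProviderError, so a caller such as a Streamlit app needs to catch only that one type.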

# --- Ollama ---
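class OllamaProvider(ChatProvider):
    def __init__(self, base_url: str = "http://localhost:11434"):
        # Strip a trailing slash so the endpoint URL built below never contains "//".
        self.base_url = base_url.rstrip("/")

    def generate(self, messages: List[Dict[str, Any]], model: str = "llama3.1:8b", temperature: float = 0.3, **kwargs) -> str:
        url = f"{self.base_url}/api/chat"
        # stream=False asks Ollama for one JSON object instead of a stream of chunks;
        # sampling parameters such as temperature go under "options".
        payload = {"model": model, "messages": messages, "stream": False, "options": {"temperature": temperature}}
        try:
            # Generous timeout (seconds): local models can be slow, especially on first load.
            r = requests.post(url, json=payload, timeout=120)
            r.raise_for_status()
            data = r.json()
            # Native Ollama responses carry the reply in data["message"]["content"].
            msg = (data.get("message") or {}).get("content", "")
            # Fallback for servers that answer with an OpenAI-style "choices" body.
            if not msg and data.get("choices"):
                msg = data["choices"][0]["message"]["content"]
            return msg or ""
        except Exception as e:
            raise ProviderError(f"Ollama API error: {e}") from e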
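
# Usage sketch (assumes an Ollama server on the default port 11434 and that the
# llama3.1:8b model has already been pulled, e.g. with `ollama pull llama3.1:8b`):
#
#     provider = OllamaProvider()
#     reply = provider.generate([{"role": "user", "content": "Say hello in one word."}])
#
# For a remote server, pass its address instead, e.g. OllamaProvider("http://host:11434");
# a trailing slash on the address is harmless because __init__ strips it.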

# --- Helper ---
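def convert_streamlit_messages_to_openai(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert Streamlit chat history into plain {"role", "content"} message dicts.

    Attachment contents are not sent: only each attachment's name and type are
    appended to the message text, so the model knows which files were attached.
    """
    converted = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content") or ""
        attachments = m.get("attachments") or []
        if attachments:
            # Real newlines ("\n"), not the literal backslash-n sequence ("\\n").
            listing = "\n".join(f"- {a.get('name', 'file')} ({a.get('type', 'file')})" for a in attachments)
            content += "\n\n[Attachments]\n" + listing
        converted.append({"role": role, "content": content})
    return converted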
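
# Usage sketch with a hypothetical Streamlit history entry (the attachment dict shape,
# {"name": ..., "type": ...}, is the one this helper reads; the values are illustrative):
#
#     history = [{"role": "user", "content": "Summarise the notes.",
#                 "attachments": [{"name": "notes.txt", "type": "text/plain"}]}]
#     messages = convert_streamlit_messages_to_openai(history)
#     # messages[0]["content"] ends with an "[Attachments]" section listing notes.txt;
#     # the result can be passed to either provider's generate().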