# Commit: e68d535 · File size: 2,056 bytes
import itertools
from typing import Any, List, Optional, Union

from google import genai
from google.genai.types import GenerateContentConfig
from openai import OpenAI
from pydantic import BaseModel

class GenerativeModelConfig(BaseModel):
    """Base configuration shared by generative (text-generation) model wrappers.

    Attributes:
        model_name: Identifier of the model to call on the provider's API.
    """

    model_name: str


class GenerativeModel:
    """Base class for generative (text-generation) model wrappers.

    Subclasses keep their settings in ``self.config`` and override
    :meth:`generate_response`.
    """

    def __init__(self, config: Any):
        """Store *config* on the instance as ``self.config``.

        Args:
            config: Provider-specific settings object; the base class does not
                inspect it.
        """
        self.config = config

    def generate_response(self, prompt: str) -> str:
        """Return the model's text response to *prompt*.

        Subclasses must override this; they may accept extra optional
        keyword parameters (e.g. a per-call ``temperature``).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement generate_response()"
        )


class GeminiModelConfig(GenerativeModelConfig):
    """Settings for ``GeminiModel``.

    Attributes:
        model_name: Gemini model identifier (inherited from
            ``GenerativeModelConfig``).
        api_keys: API keys; ``GeminiModel`` creates one client per unique key
            and rotates between them on successive calls.
        temperature: Default sampling temperature sent with each request.
    """

    api_keys: List[str]
    temperature: float = 0.0

class OpenAIModelConfig(GenerativeModelConfig):
    """Settings for ``OpenAIModel``.

    Attributes:
        model_name: OpenAI chat model identifier (inherited from
            ``GenerativeModelConfig``).
        api_key: API key used to construct the ``OpenAI`` client.
        temperature: Default sampling temperature sent with each request.
    """

    api_key: str
    temperature: float = 0.0

class GeminiModel(GenerativeModel):
    """Text-generation model backed by the Gemini API via ``google.genai``.

    One ``genai.Client`` is created per distinct API key, and successive
    :meth:`generate_response` calls rotate through those clients round-robin.
    """

    def __init__(self, config: GeminiModelConfig):
        """Create one client per unique API key in *config*.

        Args:
            config: Gemini settings. The caller's object is not mutated; a
                copy with de-duplicated ``api_keys`` is stored as
                ``self.config``.

        Raises:
            ValueError: If ``config.api_keys`` is empty.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # rotation order is deterministic (set() order depends on hash seed).
        unique_keys = list(dict.fromkeys(config.api_keys))
        if not unique_keys:
            # Without this check, next() on an empty cycle would surface later
            # as an opaque StopIteration inside generate_response().
            raise ValueError(
                "GeminiModelConfig.api_keys must contain at least one API key"
            )
        # Store a copy so the caller's config object is left untouched.
        super().__init__(config.model_copy(update={"api_keys": unique_keys}))
        self.clients = [genai.Client(api_key=key) for key in unique_keys]
        self._client_cycle = itertools.cycle(self.clients)

    def generate_response(
        self,
        prompt: str,
        temperature: Optional[float] = None,
    ) -> str:
        """Generate a response to *prompt* using the next client in the rotation.

        Args:
            prompt: The prompt sent as the request contents.
            temperature: Optional per-call override; when ``None`` the value
                from ``self.config.temperature`` is used.

        Returns:
            The ``text`` of the Gemini response.
        """
        if temperature is None:
            temperature = self.config.temperature
        client = next(self._client_cycle)
        response = client.models.generate_content(
            model=self.config.model_name,
            contents=prompt,
            config=GenerateContentConfig(temperature=temperature),
        )
        # NOTE(review): response.text may be None when no text candidate is
        # returned (e.g. blocked output) — confirm against the SDK and callers.
        return response.text

class OpenAIModel(GenerativeModel):
    """Text-generation model backed by the OpenAI Chat Completions API."""

    def __init__(self, config: OpenAIModelConfig):
        """Create an ``OpenAI`` client authenticated with ``config.api_key``.

        Args:
            config: OpenAI settings, stored as ``self.config``.
        """
        super().__init__(config)
        self.client = OpenAI(api_key=config.api_key)

    def generate_response(
        self,
        prompt: str,
        temperature: Optional[float] = None,
    ) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: Content of the single ``"user"`` message.
            temperature: Optional per-call override; when ``None`` the value
                from ``self.config.temperature`` is used.

        Returns:
            ``message.content`` of the first returned choice.
        """
        if temperature is None:
            temperature = self.config.temperature

        response = self.client.chat.completions.create(
            model=self.config.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        # NOTE(review): message.content can be None for some responses
        # (e.g. refusals/tool calls) — confirm callers tolerate that.
        return response.choices[0].message.content