import logging
import os
from typing import Any, List, Optional

import ollama

from src.config import Config


class OllamaMistral:
    """
    A class to interact with the Ollama API for the Mistral model.
    Handles both chat completions and embeddings generation.
    """
    
    def __init__(self):
        """Initialize the Ollama Mistral client with default settings."""
        self.logger = logging.getLogger(__name__)
        # Initialize Ollama client with default host
        self.client = ollama.Client(host='http://localhost:11434')
        self.model = 'mistral'  # Default model name

    async def generate_response(self, prompt: str) -> str:
        """
        Asynchronously generate a text response from the Mistral model.
        
        Args:
            prompt: The input text prompt for the model
            
        Returns:
            Generated response text or error message if failed
        """
        try:
            print(f"[Ollama] Sending prompt:\n{prompt}\n")
            # Send chat request to the Ollama API (Client.chat() is synchronous, so nothing is awaited here)
            response = self.client.chat(
                model=self.model,
                messages=[{
                    'role': 'user',
                    'content': prompt
                }]
            )
            print(f"[Ollama] Received response:\n{response}\n")

            # Handle different response formats from Ollama
            if isinstance(response, dict):
                if 'message' in response and 'content' in response['message']:
                    return response['message']['content']
            elif hasattr(response, 'message') and hasattr(response.message, 'content'):
                return response.message.content
            # Fallback: try to convert to string
            return str(response)

        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating response: {str(e)}", exc_info=True)
            return f"Error generating response: {str(e)}"

    def generate_embedding(self, text: str, model: str = Config.OLLAMA_MODEL) -> Optional[List[float]]:
        """
        Generate embeddings for the input text using specified model.
        
        Args:
            text: Input text to generate embeddings for
            model: Model name to use for embeddings (default from Config)
            
        Returns:
            List of embeddings or None if failed
        """
        try:
            print(f"[Ollama] Generating embedding for: {text[:60]}...")
            # Request embeddings from the Ollama API (embeddings() takes a single prompt string)
            response = self.client.embeddings(
                model=model,
                prompt=text
            )
            print(f"[Ollama] Embedding response: {response}")

            # Handle different response formats (plain dict or typed response object)
            if isinstance(response, dict) and 'embedding' in response:
                return list(response['embedding'])
            elif isinstance(response, dict) and 'embeddings' in response:
                return list(response['embeddings'][0])
            elif hasattr(response, 'embedding'):
                return list(response.embedding)
            else:
                self.logger.warning(f"Unexpected embedding response format: {response}")
                return None

        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating embedding: {str(e)}", exc_info=True)
            return None

    def generate(self, prompt: str) -> str:
        """
        Synchronous wrapper for generate_response.
        
        Args:
            prompt: Input text prompt
            
        Returns:
            Generated response text
        """
        import asyncio
        try:
            return asyncio.run(self.generate_response(prompt))
        except Exception as e:
            self.logger.error(f"Error in synchronous generate: {e}")
            return f"Error generating response: {str(e)}"

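# Illustrative usage sketch for OllamaMistral (not executed on import). It assumes a local
# Ollama server is reachable at http://localhost:11434 with the 'mistral' model pulled, and
# that Config.OLLAMA_MODEL names an embedding-capable model:
#
#   llm = OllamaMistral()
#   reply = llm.generate("Summarise the factory pattern in one sentence.")
#   vector = llm.generate_embedding("factory pattern")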

class GeminiProvider:
    """
    A class to interact with Google's Gemini API.
    Requires the GEMINI_API_KEY environment variable.
    """
    
    def __init__(self):
        """Initialize Gemini provider with API key."""
        self.logger = logging.getLogger(__name__)
        self.api_key = os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required for Gemini provider")
        
        try:
            import google.generativeai as genai
            # Configure Gemini API
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel('gemini-1.5-flash')
        except ImportError:
            raise ImportError("google-generativeai package is required for Gemini provider")

    def generate(self, prompt: str) -> str:
        """
        Generate text response using Gemini model.
        
        Args:
            prompt: Input text prompt
            
        Returns:
            Generated response text or error message
        """
        try:
            response = self.model.generate_content(prompt)
            return response.text
        except Exception as e:
            self.logger.error(f"[Gemini] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"

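# Illustrative usage sketch for GeminiProvider (not executed on import). It assumes the
# GEMINI_API_KEY environment variable is set and the google-generativeai package is
# installed; 'gemini-1.5-flash' is the model hard-coded in __init__ above:
#
#   gemini = GeminiProvider()
#   reply = gemini.generate("Explain embeddings in one sentence.")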

class OpenChatProvider:
    """
    A class to use OpenChat models locally via transformers.
    Requires the transformers package to be installed.
    """
    
    def __init__(self):
        """Initialize OpenChat model and tokenizer."""
        self.logger = logging.getLogger(__name__)
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            # Load pretrained OpenChat model
            self.tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
            self.model = AutoModelForCausalLM.from_pretrained("openchat/openchat-3.5-0106")
        except ImportError:
            raise ImportError("transformers package is required for OpenChat provider")

    def generate(self, prompt: str) -> str:
        """
        Generate text response using OpenChat model.
        
        Args:
            prompt: Input text prompt
            
        Returns:
            Generated response text
        """
        try:
            # Tokenize the prompt and generate a response (sampling enabled so temperature has an effect)
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            self.logger.error(f"[OpenChat] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"

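# Illustrative usage sketch for OpenChatProvider (not executed on import). Instantiating it
# downloads the openchat/openchat-3.5-0106 weights from the Hugging Face Hub on first use,
# so a sizeable download and enough RAM (or a GPU) are assumed:
#
#   openchat = OpenChatProvider()
#   reply = openchat.generate("Write a haiku about factories.")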

class LLMFactory:
    """
    Factory class to create and manage different LLM providers.
    Implements the Factory design pattern for LLM provider instantiation.
    """
    
    @staticmethod
    def get_provider(model_name: Optional[str] = None) -> Any:
        """
        Get appropriate LLM provider based on model name.
        
        Args:
            model_name: Name of the model ('mistral', 'gemini', 'openchat')
                      Defaults to 'mistral' if None or unknown
            
        Returns:
            Instance of the requested LLM provider
            
        Raises:
            ValueError: If a required API key is missing for the selected provider
            ImportError: If a required package is not installed for the selected provider
        """
        if model_name is None:
            model_name = "mistral"  # Default to mistral
            
        model_name = model_name.lower()
        
        # Return appropriate provider based on model name
        if model_name == "mistral":
            return OllamaMistral()
        elif model_name == "gemini":
            return GeminiProvider()
        elif model_name == "openchat":
            return OpenChatProvider()
        else:
            # Default to mistral if unknown model is specified
            logging.warning(f"Unknown model '{model_name}', defaulting to mistral")
            return OllamaMistral()
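

# Minimal manual smoke test for the factory (a sketch, not part of the module's public API).
# It assumes a local Ollama server is running for the default provider; LLM_MODEL below is a
# hypothetical environment variable used only by this demo to pick 'gemini' or 'openchat'.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = LLMFactory.get_provider(os.getenv("LLM_MODEL", "mistral"))
    print(provider.generate("Say hello in five words."))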