import asyncio
import logging
import os
from typing import List, Optional, Union

import ollama

from src.config import Config


class OllamaMistral:
    """
    A class to interact with the Ollama API for the Mistral model.

    Handles both chat completions and embedding generation.
    """

    def __init__(self):
        """Initialize the Ollama Mistral client with default settings."""
        self.logger = logging.getLogger(__name__)
        # Initialize the Ollama client against the default local host
        self.client = ollama.Client(host='http://localhost:11434')
        self.model = 'mistral'  # Default model name

    async def generate_response(self, prompt: str) -> str:
        """
        Asynchronously generate a text response from the Mistral model.

        Args:
            prompt: The input text prompt for the model.

        Returns:
            Generated response text, or an error message if the call failed.
        """
        try:
            print(f"[Ollama] Sending prompt:\n{prompt}\n")

            # Send the chat request to the Ollama API
            response = self.client.chat(
                model=self.model,
                messages=[{'role': 'user', 'content': prompt}]
            )

            print(f"[Ollama] Received response:\n{response}\n")

            # Ollama may return either a plain dict or a response object,
            # depending on the library version; handle both.
            if isinstance(response, dict):
                message = response.get('message', {})
                if isinstance(message, dict) and 'content' in message:
                    return message['content']
            elif hasattr(response, 'message') and hasattr(response.message, 'content'):
                return response.message.content

            # Fallback: return whatever we got as a string
            return str(response)
        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating response: {str(e)}", exc_info=True)
            return f"Error generating response: {str(e)}"

    def generate_embedding(self, text: str, model: str = Config.OLLAMA_MODEL) -> Optional[List[float]]:
        """
        Generate embeddings for the input text using the specified model.

        Args:
            text: Input text to generate embeddings for.
            model: Model name to use for embeddings (defaults to Config.OLLAMA_MODEL).

        Returns:
            List of embedding floats, or None if the call failed.
        """
        try:
            print(f"[Ollama] Generating embedding for: {text[:60]}...")

            # Request embeddings from the Ollama API; the embeddings endpoint
            # takes a single prompt string
            response = self.client.embeddings(
                model=model,
                prompt=text
            )

            print(f"[Ollama] Embedding response: {response}")

            # Handle dict responses from older client versions as well as
            # response objects that expose an 'embedding' attribute
            if isinstance(response, dict) and 'embedding' in response:
                return response['embedding']
            elif isinstance(response, dict) and 'embeddings' in response:
                return response['embeddings'][0]
            elif hasattr(response, 'embedding'):
                return list(response.embedding)
            else:
                self.logger.warning(f"Unexpected embedding response format: {response}")
                return None
        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating embedding: {str(e)}", exc_info=True)
            return None

    def generate(self, prompt: str) -> str:
        """
        Synchronous wrapper for generate_response.

        Args:
            prompt: Input text prompt.

        Returns:
            Generated response text.
        """
        try:
            return asyncio.run(self.generate_response(prompt))
        except Exception as e:
            self.logger.error(f"Error in synchronous generate: {e}")
            return f"Error generating response: {str(e)}"


class GeminiProvider:
    """
    A class to interact with Google's Gemini API.

    Requires the GEMINI_API_KEY environment variable.
    """

    def __init__(self):
        """Initialize the Gemini provider with an API key."""
        self.logger = logging.getLogger(__name__)
        self.api_key = os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required for the Gemini provider")

        try:
            import google.generativeai as genai

            # Configure the Gemini API client
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel('gemini-1.5-flash')
        except ImportError:
            raise ImportError("google-generativeai package is required for the Gemini provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the Gemini model.

        Args:
            prompt: Input text prompt.

        Returns:
            Generated response text, or an error message if the call failed.
        """
        try:
            response = self.model.generate_content(prompt)
            return response.text
        except Exception as e:
            self.logger.error(f"[Gemini] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"


class OpenChatProvider:
    """
    A class to run OpenChat models locally via transformers.

    Requires the transformers package to be installed.
    """

    def __init__(self):
        """Initialize the OpenChat model and tokenizer."""
        self.logger = logging.getLogger(__name__)
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM

            # Load the pretrained OpenChat model and tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
            self.model = AutoModelForCausalLM.from_pretrained("openchat/openchat-3.5-0106")
        except ImportError:
            raise ImportError("transformers package is required for the OpenChat provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the OpenChat model.

        Args:
            prompt: Input text prompt.

        Returns:
            Generated response text, or an error message if the call failed.
        """
        try:
            # Tokenize the input and generate a response; do_sample=True is
            # needed for the temperature setting to take effect
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=512, do_sample=True, temperature=0.7)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            self.logger.error(f"[OpenChat] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"


class LLMFactory:
    """
    Factory class to create and manage the different LLM providers.

    Implements the Factory design pattern for LLM provider instantiation.
    """

    @staticmethod
    def get_provider(model_name: Optional[str] = None) -> Union[OllamaMistral, GeminiProvider, OpenChatProvider]:
        """
        Get the appropriate LLM provider based on the model name.

        Args:
            model_name: Name of the model ('mistral', 'gemini', 'openchat').
                Defaults to 'mistral' if None or unknown.

        Returns:
            Instance of the requested LLM provider.

        Raises:
            ValueError: If GEMINI_API_KEY is missing for the Gemini provider.
            ImportError: If a provider's required package is not installed.
        """
        if model_name is None:
            model_name = "mistral"  # Default to mistral

        model_name = model_name.lower()

        # Return the appropriate provider based on the model name
        if model_name == "mistral":
            return OllamaMistral()
        elif model_name == "gemini":
            return GeminiProvider()
        elif model_name == "openchat":
            return OpenChatProvider()
        else:
            # Default to mistral if an unknown model is specified
            logging.warning(f"Unknown model '{model_name}', defaulting to mistral")
            return OllamaMistral()