import asyncio
import logging
import os
from typing import Any, List, Optional

import ollama

from src.config import Config


class OllamaMistral:
    """
    A class to interact with the Ollama API for the Mistral model.

    Handles both chat completions and embeddings generation.
    """

    def __init__(self):
        """Initialize the Ollama Mistral client with default settings."""
        self.logger = logging.getLogger(__name__)
        # Initialize the Ollama client against the default local host/port.
        self.client = ollama.Client(host='http://localhost:11434')
        self.model = 'mistral'  # Default model name

    async def generate_response(self, prompt: str) -> str:
        """
        Asynchronously generate a text response from the Mistral model.

        Args:
            prompt: The input text prompt for the model

        Returns:
            Generated response text, or an error message if the call failed
        """
        try:
            print(f"[Ollama] Sending prompt:\n{prompt}\n")
            # Send the chat request to the Ollama API. Note that
            # ollama.Client.chat is a blocking call; this coroutine is async
            # only to match the rest of the interface.
            response = self.client.chat(
                model=self.model,
                messages=[{
                    'role': 'user',
                    'content': prompt
                }]
            )
            print(f"[Ollama] Received response:\n{response}\n")
            # Handle both response shapes returned by different ollama
            # versions: a plain dict, or an object with a .message attribute.
            if isinstance(response, dict):
                if 'message' in response and 'content' in response['message']:
                    return response['message']['content']
            elif hasattr(response, 'message') and hasattr(response.message, 'content'):
                return response.message.content
            # Fallback: try to convert to string
            return str(response)
        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating response: {str(e)}", exc_info=True)
            return f"Error generating response: {str(e)}"

    def generate_embedding(self, text: str, model: str = Config.OLLAMA_MODEL) -> Optional[List[float]]:
        """
        Generate an embedding for the input text using the specified model.

        Args:
            text: Input text to generate an embedding for
            model: Model name to use for embeddings (default from Config)

        Returns:
            List of embedding floats, or None if the call failed
        """
        try:
            print(f"[Ollama] Generating embedding for: {text[:60]}...")
            # Request an embedding from the Ollama API; the embeddings
            # endpoint takes a single string via the 'prompt' keyword.
            response = self.client.embeddings(
                model=model,
                prompt=text
            )
            print(f"[Ollama] Embedding response: {response}")
            # Handle both response shapes: some versions return 'embeddings'
            # (a list of vectors), others a single 'embedding'.
            if isinstance(response, dict) and 'embeddings' in response:
                return response['embeddings'][0]
            elif isinstance(response, dict) and 'embedding' in response:
                return response['embedding']
            else:
                self.logger.warning(f"Unexpected embedding response format: {response}")
                return None
        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating embedding: {str(e)}", exc_info=True)
            return None

    def generate(self, prompt: str) -> str:
        """
        Synchronous wrapper for generate_response.

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text
        """
        try:
            # generate_response is a coroutine, so drive it to completion here.
            return asyncio.run(self.generate_response(prompt))
        except Exception as e:
            self.logger.error(f"Error in synchronous generate: {e}")
            return f"Error generating response: {str(e)}"


class GeminiProvider:
    """
    A class to interact with Google's Gemini API.

    Requires the GEMINI_API_KEY environment variable.
    """

    def __init__(self):
        """Initialize the Gemini provider with an API key."""
        self.logger = logging.getLogger(__name__)
        self.api_key = os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required for Gemini provider")
        try:
            import google.generativeai as genai
            # Import and configure the Gemini SDK lazily, so the dependency
            # is only required when this provider is actually used.
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel('gemini-1.5-flash')
        except ImportError:
            raise ImportError("google-generativeai package is required for Gemini provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the Gemini model.

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text or error message
        """
        try:
            response = self.model.generate_content(prompt)
            return response.text
        except Exception as e:
            self.logger.error(f"[Gemini] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"


class OpenChatProvider:
    """
    A class to run OpenChat models locally via transformers.

    Requires the transformers package to be installed.
    """

    def __init__(self):
        """Initialize the OpenChat model and tokenizer."""
        self.logger = logging.getLogger(__name__)
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            # Load the pretrained OpenChat model and its tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
            self.model = AutoModelForCausalLM.from_pretrained("openchat/openchat-3.5-0106")
        except ImportError:
            raise ImportError("transformers package is required for OpenChat provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the OpenChat model.

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text
        """
        try:
            # Tokenize the input and generate a response. do_sample=True is
            # required for temperature to take effect; without it, transformers
            # uses greedy decoding and silently ignores the temperature setting.
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=512, do_sample=True, temperature=0.7)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            self.logger.error(f"[OpenChat] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"


class LLMFactory:
    """
    Factory class to create and manage different LLM providers.

    Implements the Factory design pattern for LLM provider instantiation.
    """

    @staticmethod
    def get_provider(model_name: Optional[str] = None) -> Any:
        """
        Get the appropriate LLM provider based on model name.

        Args:
            model_name: Name of the model ('mistral', 'gemini', 'openchat').
                Defaults to 'mistral' if None or unknown.

        Returns:
            Instance of the requested LLM provider

        Raises:
            ValueError: If a required API key is missing (Gemini)
            ImportError: If the provider's dependencies are not installed
        """
        if model_name is None:
            model_name = "mistral"  # Default to mistral
        model_name = model_name.lower()
        # Return the appropriate provider based on model name.
        if model_name == "mistral":
            return OllamaMistral()
        elif model_name == "gemini":
            return GeminiProvider()
        elif model_name == "openchat":
            return OpenChatProvider()
        else:
            # Fall back to mistral if an unknown model is specified.
            logging.warning(f"Unknown model '{model_name}', defaulting to mistral")
            return OllamaMistral()
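

if __name__ == "__main__":
    # Minimal smoke test for the factory (a sketch, not part of the original
    # API surface): it assumes a local Ollama server is running with the
    # 'mistral' model pulled; pass 'gemini' or 'openchat' instead to exercise
    # the other providers if they are configured.
    logging.basicConfig(level=logging.INFO)
    provider = LLMFactory.get_provider("mistral")
    print(provider.generate("Say hello in one short sentence."))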