import json
import os
from typing import List, Union

import platformdirs

try:
    from ollama import Client, Image, Message
except ImportError:
    raise ImportError(
        "If you'd like to use Ollama, please install the ollama package by running "
        "`pip install ollama`, and make sure an Ollama server is available."
    )

from .base import CachedEngine, EngineLM
class ChatOllama(EngineLM, CachedEngine):
    """
    Ollama implementation of the EngineLM interface.

    This allows using any model supported by Ollama.
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
    def __init__(
        self,
        model_string="qwen2.5vl:3b",
        system_prompt=DEFAULT_SYSTEM_PROMPT,
        is_multimodal: bool = False,
        use_cache: bool = True,
        **kwargs,
    ):
        """
        :param model_string: Name of the Ollama model to use. If no tag is given,
            ":latest" is appended automatically (e.g. "llama3" -> "llama3:latest").
        :param system_prompt: Default system prompt sent with every request.
        :param is_multimodal: Whether the model accepts image inputs.
        :param use_cache: Whether to cache responses on disk.
        """
        # Ollama identifies models as "name:tag"; default to ":latest" when
        # the caller omits the tag.
        self.model_string = (
            model_string if ":" in model_string else f"{model_string}:latest"
        )
        self.use_cache = use_cache
        self.system_prompt = system_prompt
        self.is_multimodal = is_multimodal

        if self.use_cache:
            root = platformdirs.user_cache_dir("agentflow")
            cache_path = os.path.join(root, f"cache_ollama_{self.model_string}.db")
            self.image_cache_dir = os.path.join(root, "image_cache")
            os.makedirs(self.image_cache_dir, exist_ok=True)
            super().__init__(cache_path=cache_path)
        try:
            self.client = Client(host="http://localhost:11434")
        except Exception as e:
            raise ValueError(f"Failed to connect to Ollama server: {e}")

        models = self.client.list().models
        if len(models) == 0:
            raise ValueError(
                "No models found in the Ollama server. Please ensure the server "
                "is running and has models available."
            )

        # Pull the model from the registry if it is not available locally.
        if self.model_string not in [model.model for model in models]:
            print(
                f"Model '{self.model_string}' not found. Attempting to pull it "
                "from the Ollama registry."
            )
            try:
                self.client.pull(self.model_string)
            except Exception as e:
                raise ValueError(f"Failed to pull model '{self.model_string}': {e}")
    def generate(
        self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs
    ):
        if isinstance(content, str):
            return self._generate_text(content, system_prompt=system_prompt, **kwargs)
        elif isinstance(content, list):
            if not self.is_multimodal:
                raise NotImplementedError(
                    f"Multimodal generation is not enabled for {self.model_string}. "
                    "Pass is_multimodal=True if the model supports image inputs."
                )
            return self._generate_multimodal(
                content, system_prompt=system_prompt, **kwargs
            )
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")
    def _generate_text(
        self,
        prompt,
        system_prompt=None,
        temperature=0,
        max_tokens=4000,
        top_p=0.99,
        response_format=None,
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        if self.use_cache:
            cache_key = sys_prompt_arg + prompt
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        # If a response_format (a pydantic model class) is given, constrain the
        # output to its JSON schema via Ollama's structured-output support.
        response = self.client.chat(
            model=self.model_string,
            messages=[
                {"role": "system", "content": sys_prompt_arg},
                {"role": "user", "content": prompt},
            ],
            format=response_format.model_json_schema() if response_format else None,
            options={
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "stop": None,
                "temperature": temperature,
                # Ollama calls the generation-length option "num_predict",
                # not "max_tokens".
                "num_predict": max_tokens,
                "top_p": top_p,
            },
        )
        response = response.message.content

        if self.use_cache:
            self._save_cache(cache_key, response)
        return response
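    # A hypothetical structured-output call (it assumes pydantic, which the
    # `response_format.model_json_schema()` call above already presupposes;
    # `Answer` and `engine` are illustrative names, not part of this module):
    #
    #   from pydantic import BaseModel
    #
    #   class Answer(BaseModel):
    #       answer: str
    #
    #   engine._generate_text("What is 2 + 2?", response_format=Answer)
    #
    # The model is then constrained to emit JSON matching Answer's schema.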
    def __call__(self, prompt, **kwargs):
        return self.generate(prompt, **kwargs)
    def _format_content(self, content: List[Union[str, bytes]]) -> Message:
        """
        Formats the input content into a Message object for Ollama.

        Text items are joined with newlines; bytes items are treated as raw
        image data.
        """
        text_parts = []
        images = []
        for item in content:
            if isinstance(item, bytes):
                # ollama.Image is a pydantic model, so the field must be
                # passed by keyword.
                images.append(Image(value=item))
            elif isinstance(item, str):
                text_parts.append(item)
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
        return Message(
            role="user",
            content="\n".join(text_parts) if text_parts else None,
            images=images if images else None,
        )
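    # For example, a hypothetical call mixing text and raw image bytes
    # (`img_bytes` is a placeholder name):
    #
    #   self._format_content(["Describe this image:", img_bytes])
    #
    # returns Message(role="user", content="Describe this image:",
    # images=[Image(value=img_bytes)]).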
    def _generate_multimodal(
        self,
        content: List[Union[str, bytes]],
        system_prompt=None,
        temperature=0,
        max_tokens=4000,
        top_p=0.99,
        response_format=None,
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        message = self._format_content(content)

        if self.use_cache:
            # Message is a pydantic model (and may contain raw image bytes),
            # so serialize it with pydantic rather than json.dumps.
            cache_key = sys_prompt_arg + message.model_dump_json()
            cache_or_none = self._check_cache(cache_key)
            if cache_or_none is not None:
                return cache_or_none

        response = self.client.chat(
            model=self.model_string,
            messages=[
                {"role": "system", "content": sys_prompt_arg},
                {
                    "role": message.role,
                    "content": message.content,
                    "images": message.images if message.images else None,
                },
            ],
            format=response_format.model_json_schema() if response_format else None,
            options={
                "temperature": temperature,
                # Ollama's name for the generation-length option.
                "num_predict": max_tokens,
                "top_p": top_p,
            },
        )
        response_text = response.message.content

        if self.use_cache:
            self._save_cache(cache_key, response_text)
        return response_text
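

# A minimal usage sketch, not part of the engine itself. It assumes an Ollama
# server is running on localhost:11434 and that the requested model exists
# locally or can be pulled; the image path is a placeholder.
if __name__ == "__main__":
    engine = ChatOllama(model_string="qwen2.5vl:3b", is_multimodal=True)

    # Plain text generation via __call__ -> generate -> _generate_text.
    print(engine("Explain what a system prompt does in one sentence."))

    # Multimodal generation: a list mixing strings and raw image bytes.
    # with open("example.jpg", "rb") as f:
    #     print(engine(["Describe this image:", f.read()]))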