| from typing import Any |
| from urllib.parse import urljoin |
|
|
| import httpx |
| from langchain_ollama import ChatOllama |
|
|
| from langflow.base.models.model import LCModelComponent |
| from langflow.field_typing import LanguageModel |
| from langflow.inputs.inputs import HandleInput |
| from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, StrInput |
|
|
|
|
class ChatOllamaComponent(LCModelComponent):
    """Langflow component that builds a ``ChatOllama`` chat model.

    Exposes the Ollama generation options (mirostat, context window, sampling
    knobs, keep-alive, …) as Langflow inputs, can refresh the model list from a
    running Ollama server, and constructs a ``langchain_ollama.ChatOllama``
    instance in :meth:`build_model`.
    """

    display_name = "Ollama"
    description = "Generate text using Ollama Local LLMs."
    icon = "Ollama"
    name = "OllamaModel"

    def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
        """React to a UI field change and adjust the build configuration.

        Handled fields:
            * ``mirostat`` — hides/clears the eta/tau tuning fields when
              sampling is disabled; otherwise shows them and applies the
              per-version defaults (eta 0.2 / tau 10 for "Mirostat 2.0",
              eta 0.1 / tau 5 otherwise).
            * ``model_name`` — refreshes the model dropdown options from the
              Ollama server at the configured base URL.
            * ``keep_alive_flag`` — maps the friendly choice onto the raw
              ``keep_alive`` value ("-1" = keep loaded forever,
              "0" = unload immediately; any other choice re-exposes the field).

        Args:
            build_config: The component's build configuration (mutated in place).
            field_value: The new value of the changed field.
            field_name: Name of the field that changed, if any.

        Returns:
            The (mutated) ``build_config``.
        """
        if field_name == "mirostat":
            if field_value == "Disabled":
                # Mirostat off: tuck the tuning knobs away and clear their values
                # so stale numbers are not sent to the model.
                build_config["mirostat_eta"]["advanced"] = True
                build_config["mirostat_tau"]["advanced"] = True
                build_config["mirostat_eta"]["value"] = None
                build_config["mirostat_tau"]["value"] = None

            else:
                build_config["mirostat_eta"]["advanced"] = False
                build_config["mirostat_tau"]["advanced"] = False

                if field_value == "Mirostat 2.0":
                    build_config["mirostat_eta"]["value"] = 0.2
                    build_config["mirostat_tau"]["value"] = 10
                else:
                    build_config["mirostat_eta"]["value"] = 0.1
                    build_config["mirostat_tau"]["value"] = 5

        if field_name == "model_name":
            base_url_dict = build_config.get("base_url", {})
            base_url_load_from_db = base_url_dict.get("load_from_db", False)
            base_url_value = base_url_dict.get("value")
            if base_url_load_from_db:
                # The stored value is a variable reference; resolve it first.
                # NOTE(review): ``self.variables`` is provided by the base
                # component — verify its signature against LCModelComponent.
                base_url_value = self.variables(base_url_value, field_name)
            elif not base_url_value:
                base_url_value = "http://localhost:11434"
            build_config["model_name"]["options"] = self.get_model(base_url_value)
        if field_name == "keep_alive_flag":
            if field_value == "Keep":
                build_config["keep_alive"]["value"] = "-1"
                build_config["keep_alive"]["advanced"] = True
            elif field_value == "Immediately":
                build_config["keep_alive"]["value"] = "0"
                build_config["keep_alive"]["advanced"] = True
            else:
                build_config["keep_alive"]["advanced"] = False

        return build_config

    def get_model(self, base_url_value: str) -> list[str]:
        """Return the model names available on the Ollama server.

        Args:
            base_url_value: Base URL of the Ollama API
                (e.g. ``http://localhost:11434``).

        Returns:
            The list of model names reported by the ``/api/tags`` endpoint.

        Raises:
            ValueError: If the server cannot be reached or returns an
                unexpected payload.
        """
        try:
            url = urljoin(base_url_value, "/api/tags")
            with httpx.Client() as client:
                response = client.get(url)
                response.raise_for_status()
                data = response.json()

            return [model["name"] for model in data.get("models", [])]
        except Exception as e:
            # Deliberately broad: any failure (network, HTTP status, malformed
            # JSON/payload) is converted into one friendly, catchable error.
            msg = "Could not retrieve models. Please, make sure Ollama is running."
            raise ValueError(msg) from e

    inputs = [
        StrInput(
            name="base_url",
            display_name="Base URL",
            info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
            value="http://localhost:11434",
        ),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            value="llama3.1",
            info="Refer to https://ollama.com/library for more models.",
            refresh_button=True,
        ),
        FloatInput(
            name="temperature",
            display_name="Temperature",
            value=0.2,
            info="Controls the creativity of model responses.",
        ),
        StrInput(
            name="format", display_name="Format", info="Specify the format of the output (e.g., json).", advanced=True
        ),
        DictInput(name="metadata", display_name="Metadata", info="Metadata to add to the run trace.", advanced=True),
        DropdownInput(
            name="mirostat",
            display_name="Mirostat",
            options=["Disabled", "Mirostat", "Mirostat 2.0"],
            info="Enable/disable Mirostat sampling for controlling perplexity.",
            value="Disabled",
            advanced=True,
            real_time_refresh=True,
        ),
        FloatInput(
            name="mirostat_eta",
            display_name="Mirostat Eta",
            info="Learning rate for Mirostat algorithm. (Default: 0.1)",
            advanced=True,
        ),
        FloatInput(
            name="mirostat_tau",
            display_name="Mirostat Tau",
            info="Controls the balance between coherence and diversity of the output. (Default: 5.0)",
            advanced=True,
        ),
        IntInput(
            name="num_ctx",
            display_name="Context Window Size",
            info="Size of the context window for generating tokens. (Default: 2048)",
            advanced=True,
        ),
        IntInput(
            name="num_gpu",
            display_name="Number of GPUs",
            info="Number of GPUs to use for computation. (Default: 1 on macOS, 0 to disable)",
            advanced=True,
        ),
        IntInput(
            name="num_thread",
            display_name="Number of Threads",
            info="Number of threads to use during computation. (Default: detected for optimal performance)",
            advanced=True,
        ),
        IntInput(
            name="repeat_last_n",
            display_name="Repeat Last N",
            info="How far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
            advanced=True,
        ),
        FloatInput(
            name="repeat_penalty",
            display_name="Repeat Penalty",
            info="Penalty for repetitions in generated text. (Default: 1.1)",
            advanced=True,
        ),
        FloatInput(name="tfs_z", display_name="TFS Z", info="Tail free sampling value. (Default: 1)", advanced=True),
        IntInput(name="timeout", display_name="Timeout", info="Timeout for the request stream.", advanced=True),
        IntInput(
            name="top_k", display_name="Top K", info="Limits token selection to top K. (Default: 40)", advanced=True
        ),
        FloatInput(name="top_p", display_name="Top P", info="Works together with top-k. (Default: 0.9)", advanced=True),
        BoolInput(name="verbose", display_name="Verbose", info="Whether to print out response text.", advanced=True),
        StrInput(
            name="tags",
            display_name="Tags",
            info="Comma-separated list of tags to add to the run trace.",
            advanced=True,
        ),
        StrInput(
            name="stop_tokens",
            display_name="Stop Tokens",
            info="Comma-separated list of tokens to signal the model to stop generating text.",
            advanced=True,
        ),
        StrInput(name="system", display_name="System", info="System to use for generating text.", advanced=True),
        StrInput(name="template", display_name="Template", info="Template to use for generating text.", advanced=True),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model",
            advanced=True,
            input_types=["OutputParser"],
        ),
        *LCModelComponent._base_inputs,
    ]

    def build_model(self) -> LanguageModel:
        """Instantiate a ``ChatOllama`` model from the component's inputs.

        Options left at a falsy "unset" value (0 / empty string) are coerced
        to ``None`` and then dropped so Ollama's own defaults apply — except
        ``temperature``, which is passed through as-is so an explicit 0.0
        (deterministic sampling) is honored.

        Returns:
            The configured ``ChatOllama`` instance.

        Raises:
            ValueError: If the ``ChatOllama`` client cannot be constructed.
        """
        # Map the dropdown label to Ollama's numeric mirostat mode; 0 disables it.
        mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}

        mirostat_value = mirostat_options.get(self.mirostat, 0)

        if mirostat_value == 0:
            # Mirostat disabled: do not send eta/tau at all.
            mirostat_eta = None
            mirostat_tau = None
        else:
            mirostat_eta = self.mirostat_eta
            mirostat_tau = self.mirostat_tau

        llm_params = {
            "base_url": self.base_url,
            "model": self.model_name,
            "mirostat": mirostat_value,
            "format": self.format,
            "metadata": self.metadata,
            "tags": self.tags.split(",") if self.tags else None,
            "mirostat_eta": mirostat_eta,
            "mirostat_tau": mirostat_tau,
            "num_ctx": self.num_ctx or None,
            "num_gpu": self.num_gpu or None,
            "num_thread": self.num_thread or None,
            "repeat_last_n": self.repeat_last_n or None,
            "repeat_penalty": self.repeat_penalty or None,
            # BUGFIX: was ``self.temperature or None``, which silently replaced
            # an explicit 0.0 with Ollama's server-side default temperature.
            "temperature": self.temperature,
            "stop": self.stop_tokens.split(",") if self.stop_tokens else None,
            "system": self.system,
            "template": self.template,
            "tfs_z": self.tfs_z or None,
            "timeout": self.timeout or None,
            "top_k": self.top_k or None,
            "top_p": self.top_p or None,
            "verbose": self.verbose,
        }

        # Drop unset options so ChatOllama falls back to its own defaults.
        llm_params = {k: v for k, v in llm_params.items() if v is not None}

        try:
            output = ChatOllama(**llm_params)
        except Exception as e:
            msg = "Could not initialize Ollama LLM."
            raise ValueError(msg) from e

        return output
|
|