from __future__ import annotations

import os
from typing import Any, Sequence

from PIL import Image

from services.exceptions import GenerationError
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.fal_service import DEFAULT_FAL_TEXT_MODEL, run_fal


@register_service
class FalAITextGenerator(TextGenerator):
    """Text generator backed by fal.ai OpenRouter "any LLM" endpoint.

    This uses the fal model "openrouter/router" by default and expects
    an underlying LLM name in the "model" field of the input payload.

    Example extra_arguments for Gemini:

        extra_arguments={"model": "google/gemini-2.5-flash"}
    """

    service_id = "fal_ai_text"

    def __init__(
        self,
        model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """
        Parameters
        ----------
        model_name:
            fal model slug, typically "openrouter/router".
        api_key:
            Optional fal.ai API key. If omitted, FAL_KEY from the environment
            will be used.
        extra_arguments:
            Extra fields for the OpenRouter router input, for example:
            - {"model": "google/gemini-2.5-flash"}
            - {"system_prompt": "...", "temperature": 0.7, "max_tokens": 1024}
        """
        super().__init__(model_name=model_name)
        self._api_key = api_key
        self._extra_arguments: dict[str, Any] = extra_arguments or {}
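
    # Illustrative construction (the Gemini slug is the example from the class
    # docstring; if api_key is omitted, FAL_KEY from the environment is used):
    #
    #     FalAITextGenerator(
    #         api_key="...",  # or rely on FAL_KEY
    #         extra_arguments={"model": "google/gemini-2.5-flash"},
    #     )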

    @classmethod
    def default_model_name(cls) -> str:
        return DEFAULT_FAL_TEXT_MODEL

    def close(self) -> None:
        # No long-lived resources are held, so there is nothing to release.
        return

    @staticmethod
    def _extract_text(response: dict[str, Any]) -> str:
        """Pull the generated text out of a fal.ai response payload."""
        output = response.get("output")
        if isinstance(output, str) and output.strip():
            return output

        text = response.get("text")
        if isinstance(text, str) and text.strip():
            return text

        raise GenerationError(
            "fal.ai text model did not return any text output."
        )
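
    # Behaviour of _extract_text, for reference. The payloads shown are
    # illustrative only; the real response shape depends on the fal.ai
    # endpoint, and only top-level "output" / "text" keys are inspected:
    #
    #     FalAITextGenerator._extract_text({"output": "Hello"})  # -> "Hello"
    #     FalAITextGenerator._extract_text({"text": "Hi"})       # -> "Hi"
    #     FalAITextGenerator._extract_text({})                   # raises GenerationError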

    def _build_arguments(self, prompt: str) -> dict[str, Any]:
        """Combine extra_arguments with the prompt and resolve the target model."""
        arguments: dict[str, Any] = dict(self._extra_arguments)
        arguments["prompt"] = prompt

        if "model" not in arguments:
            env_model = os.getenv("OPENROUTER_DEFAULT_MODEL")
            if not env_model:
                raise GenerationError(
                    "FalAITextGenerator requires a target OpenRouter model name. "
                    "Provide it via extra_arguments['model'] or set the "
                    "OPENROUTER_DEFAULT_MODEL environment variable."
                )
            arguments["model"] = env_model

        return arguments
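
    # For example, with extra_arguments={"model": "google/gemini-2.5-flash",
    # "temperature": 0.7}, _build_arguments("Summarise this") produces
    # (values illustrative):
    #
    #     {"model": "google/gemini-2.5-flash", "temperature": 0.7,
    #      "prompt": "Summarise this"}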

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for the given prompt via the configured fal.ai router model."""
        call_progress(progress, 0.1, "Encoding inputs for fal.ai text model")

        if images:
            raise GenerationError(
                "FalAITextGenerator does not currently support image inputs."
            )

        arguments = self._build_arguments(prompt=prompt)

        call_progress(progress, 0.4, "Calling fal.ai OpenRouter router model")

        response = run_fal(
            model=self.model_name,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Decoding text from fal.ai response")

        text_output = self._extract_text(response)

        call_progress(progress, 0.95, "Preparing fal.ai text result")

        return TextGenerationResult(
            provider="fal.ai",
            model=self.model_name,
            text=text_output,
            raw_response=response,
        )
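
# Minimal usage sketch (illustrative, not run on import). It assumes FAL_KEY
# is set in the environment, that the TextGenerator base class falls back to
# default_model_name() when model_name is omitted, and that
# TextGenerationResult exposes the generated string as a "text" attribute,
# mirroring the constructor call above. The Gemini slug is the example from
# the class docstring.
if __name__ == "__main__":
    generator = FalAITextGenerator(
        extra_arguments={"model": "google/gemini-2.5-flash"},
    )
    try:
        result = generator.generate("Write a one-line haiku about routers.")
        print(result.text)
    finally:
        generator.close()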