from __future__ import annotations

import os
from typing import Any, Mapping, Sequence

from PIL import Image

from services.exceptions import GenerationError
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.fal_service import DEFAULT_FAL_TEXT_MODEL, run_fal


@register_service
class FalAITextGenerator(TextGenerator):
    """Text generator backed by the fal.ai OpenRouter "any LLM" endpoint.

    The fal model id defaults to ``DEFAULT_FAL_TEXT_MODEL`` (typically
    "openrouter/router"). The underlying LLM is selected through the
    "model" field of the input payload, which is taken from
    ``extra_arguments`` or, failing that, from the
    ``OPENROUTER_DEFAULT_MODEL`` environment variable.

    Example extra_arguments for Gemini:

        extra_arguments={"model": "google/gemini-2.5-flash"}
    """

    service_id = "fal_ai_text"

    def __init__(
        self,
        model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Store the configuration; no network call is made here.

        Parameters
        ----------
        model_name:
            fal model slug, typically "openrouter/router". Forwarded to the
            base class constructor.
        api_key:
            Optional fal.ai API key, passed through to ``run_fal`` on every
            request. If omitted, FAL_KEY from the environment is expected to
            be used (that lookup is not performed by this class).
        extra_arguments:
            Extra fields for the OpenRouter router input, for example:
            - {"model": "google/gemini-2.5-flash"}
            - {"system_prompt": "...", "temperature": 0.7, "max_tokens": 1024}
        """
        super().__init__(model_name=model_name)
        self._api_key = api_key
        self._extra_arguments: dict[str, Any] = extra_arguments or {}

    @classmethod
    def default_model_name(cls) -> str:
        """Return the default fal model id (not the underlying LLM name)."""
        return DEFAULT_FAL_TEXT_MODEL

    def close(self) -> None:
        """Do nothing; this generator holds no resources to release."""
        return

    @staticmethod
    def _extract_text(response: dict[str, Any]) -> str:
        """Return the generated text found in a fal.ai response payload.

        The "output" key is checked first (OpenRouter router schema), then
        "text". A value only counts if it is a string that is not blank;
        it is returned unmodified (no stripping).

        Raises
        ------
        GenerationError
            If the response is not a mapping, or if neither key holds a
            non-blank string.
        """
        # A non-mapping payload (e.g. None) would otherwise surface as an
        # AttributeError on ``.get`` instead of the documented error type.
        if not isinstance(response, Mapping):
            raise GenerationError(
                "fal.ai text model did not return any text output."
            )

        for key in ("output", "text"):
            candidate = response.get(key)
            if isinstance(candidate, str) and candidate.strip():
                return candidate

        raise GenerationError(
            "fal.ai text model did not return any text output."
        )

    def _build_arguments(self, prompt: str) -> dict[str, Any]:
        """Build the request payload for one generation call.

        Starts from a shallow copy of the configured extra arguments (so the
        stored dict is never mutated), sets "prompt" (overriding any "prompt"
        supplied via extra arguments), and makes sure a target "model" is
        present.

        Raises
        ------
        GenerationError
            If no model is configured in the extra arguments and the
            OPENROUTER_DEFAULT_MODEL environment variable is unset or empty.
        """
        arguments: dict[str, Any] = dict(self._extra_arguments)
        arguments["prompt"] = prompt

        # Test the value, not just key presence: {"model": None} or
        # {"model": ""} must still fall back to the environment default
        # rather than sending an unusable payload.
        if not arguments.get("model"):
            env_model = os.getenv("OPENROUTER_DEFAULT_MODEL")
            if not env_model:
                raise GenerationError(
                    "FalAITextGenerator requires a target OpenRouter model name. "
                    "Provide it via extra_arguments['model'] or set the "
                    "OPENROUTER_DEFAULT_MODEL environment variable."
                )
            arguments["model"] = env_model

        return arguments

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for ``prompt`` through the fal.ai router model.

        Parameters
        ----------
        prompt:
            Text prompt placed in the "prompt" field of the request.
        images:
            Must be empty or None; image inputs are rejected.
        progress:
            Optional callback, reported to at 0.1, 0.4, 0.7 and 0.95.

        Returns
        -------
        TextGenerationResult
            Carries the provider name "fal.ai", the fal model id, the
            extracted text and the raw response.

        Raises
        ------
        GenerationError
            If images are supplied, no target model can be resolved, or the
            response contains no usable text.
        """
        call_progress(progress, 0.1, "Encoding inputs for fal.ai text model")

        if images:
            # The OpenRouter router endpoint is text only.
            raise GenerationError(
                "FalAITextGenerator does not currently support image inputs."
            )

        arguments = self._build_arguments(prompt=prompt)

        call_progress(progress, 0.4, "Calling fal.ai OpenRouter router model")
        response = run_fal(
            model=self.model_name,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Decoding text from fal.ai response")
        text_output = self._extract_text(response)

        call_progress(progress, 0.95, "Preparing fal.ai text result")
        return TextGenerationResult(
            provider="fal.ai",
            model=self.model_name,
            text=text_output,
            raw_response=response,
        )