velai / services /text /FalAITextGenerator.py
cansik's picture
Upload folder via script
3025bb3 verified
from __future__ import annotations

import os
from collections.abc import Mapping
from typing import Any, Sequence

from PIL import Image

from services.exceptions import GenerationError
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.fal_service import DEFAULT_FAL_TEXT_MODEL, run_fal
@register_service
class FalAITextGenerator(TextGenerator):
    """Text generator backed by the fal.ai OpenRouter "any LLM" endpoint.

    The fal model id defaults to ``DEFAULT_FAL_TEXT_MODEL`` (intended to be
    "openrouter/router"). The router expects the underlying LLM name in the
    "model" field of the input payload. That name comes from
    ``extra_arguments["model"]`` or, if missing or empty, from the
    ``OPENROUTER_DEFAULT_MODEL`` environment variable.

    Example extra_arguments for Gemini:
        extra_arguments={"model": "google/gemini-2.5-flash"}
    """

    # Identifier for this service; presumably the key used by the
    # service registry populated via @register_service.
    service_id = "fal_ai_text"

    def __init__(
        self,
        model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """
        Parameters
        ----------
        model_name:
            fal model slug, typically "openrouter/router". Passed through to
            the ``TextGenerator`` base class.
        api_key:
            Optional fal.ai API key, forwarded to ``run_fal`` on every
            request. If omitted, ``run_fal`` is expected to fall back to
            FAL_KEY from the environment.
        extra_arguments:
            Extra fields for the OpenRouter router input, for example:
            - {"model": "google/gemini-2.5-flash"}
            - {"system_prompt": "...", "temperature": 0.7, "max_tokens": 1024}
            The mapping is copied, so later changes made by the caller do
            not affect this generator.
        """
        super().__init__(model_name=model_name)
        self._api_key = api_key
        # Defensive copy: the caller's dict must not silently alter the
        # payload of requests made after construction.
        self._extra_arguments: dict[str, Any] = dict(extra_arguments or {})

    @classmethod
    def default_model_name(cls) -> str:
        """Return the default fal model id.

        This is the fal endpoint id (``DEFAULT_FAL_TEXT_MODEL``), not the
        underlying LLM selected via the payload's "model" field.
        """
        return DEFAULT_FAL_TEXT_MODEL

    def close(self) -> None:
        """Release resources; a no-op because no client object is held."""
        return

    @staticmethod
    def _extract_text(response: dict[str, Any]) -> str:
        """Return the generated text from a fal.ai router response.

        The "output" field (OpenRouter router schema) is checked first,
        then "text" as a fallback. Whitespace-only strings count as missing.
        The matching value is returned unmodified.

        Raises
        ------
        GenerationError
            If ``response`` is not a mapping, or neither field holds
            non-blank text.
        """
        # Guard so an unexpected payload surfaces as the service's own error
        # type rather than an AttributeError from ``.get``.
        if not isinstance(response, Mapping):
            raise GenerationError(
                "fal.ai text model returned an unexpected response of type "
                f"{type(response).__name__}."
            )
        for key in ("output", "text"):
            value = response.get(key)
            if isinstance(value, str) and value.strip():
                return value
        raise GenerationError(
            "fal.ai text model did not return any text output."
        )

    def _build_arguments(self, prompt: str) -> dict[str, Any]:
        """Build the fal router input payload for *prompt*.

        Starts from a copy of ``extra_arguments`` and sets "prompt",
        overriding any value already there. It then ensures a target LLM is
        present in "model". A missing, ``None`` or empty "model" falls back
        to the ``OPENROUTER_DEFAULT_MODEL`` environment variable.

        Raises
        ------
        GenerationError
            If no target model is configured in either place.
        """
        arguments: dict[str, Any] = dict(self._extra_arguments)
        arguments["prompt"] = prompt
        # Treat None / "" like an absent key. Sending them would presumably
        # be rejected remotely with a less helpful error than the one below.
        if not arguments.get("model"):
            env_model = os.getenv("OPENROUTER_DEFAULT_MODEL")
            if not env_model:
                raise GenerationError(
                    "FalAITextGenerator requires a target OpenRouter model name. "
                    "Provide it via extra_arguments['model'] or set the "
                    "OPENROUTER_DEFAULT_MODEL environment variable."
                )
            arguments["model"] = env_model
        return arguments

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for *prompt* via the fal.ai OpenRouter router.

        Parameters
        ----------
        prompt:
            Text prompt, sent as the router's "prompt" field.
        images:
            Not supported. Passing a non-empty sequence raises
            ``GenerationError``.
        progress:
            Optional callback, notified through ``call_progress`` at fixed
            milestones (0.1, 0.4, 0.7, 0.95).

        Returns
        -------
        TextGenerationResult
            With provider "fal.ai", the fal model id, the extracted text and
            the raw ``run_fal`` response.

        Raises
        ------
        GenerationError
            If images are given, no target LLM is configured, or the
            response contains no text.
        """
        call_progress(progress, 0.1, "Encoding inputs for fal.ai text model")
        if images:
            # The OpenRouter router endpoint is text only.
            raise GenerationError(
                "FalAITextGenerator does not currently support image inputs."
            )
        arguments = self._build_arguments(prompt=prompt)

        call_progress(progress, 0.4, "Calling fal.ai OpenRouter router model")
        response = run_fal(
            model=self.model_name,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Decoding text from fal.ai response")
        text_output = self._extract_text(response)

        call_progress(progress, 0.95, "Preparing fal.ai text result")
        return TextGenerationResult(
            provider="fal.ai",
            model=self.model_name,
            text=text_output,
            raw_response=response,
        )