"""Attēlu ģenerēšana ar Stable Diffusion.""" from __future__ import annotations import base64 import io import logging from fastapi import APIRouter, HTTPException from pydantic import BaseModel from maris_core.utils.env import get_hf_model logger = logging.getLogger(__name__) router = APIRouter() class ImageRequest(BaseModel): prompt: str width: int = 1024 height: int = 1024 steps: int = 30 guidance_scale: float = 7.5 class ImageResponse(BaseModel): image_url: str prompt: str @router.post("/generate", response_model=ImageResponse) async def generate_image(req: ImageRequest) -> ImageResponse: """Ģenerē attēlu pēc teksta apraksta.""" from maris_core.utils.hf_integration import HFIntegration hf = HFIntegration() try: model_id = get_hf_model("IMAGE_MODEL") import torch # type: ignore from diffusers import StableDiffusionPipeline # type: ignore pipe = StableDiffusionPipeline.from_pretrained( model_id, torch_dtype=torch.float16, ) pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu") image = pipe( req.prompt, width=req.width, height=req.height, num_inference_steps=req.steps, guidance_scale=req.guidance_scale, ).images[0] # Konvertē uz base64 data URL buf = io.BytesIO() image.save(buf, format="PNG") b64 = base64.b64encode(buf.getvalue()).decode() image_url = f"data:image/png;base64,{b64}" # Saglabā origin atmiņā await hf.save_generation("image", req.prompt, {"image_b64": b64[:100] + "..."}) return ImageResponse(image_url=image_url, prompt=req.prompt) except Exception as exc: # noqa: BLE001 logger.error("Attēla ģenerēšanas kļūda: %s", exc) raise HTTPException( status_code=503, detail="Maris AI attēlu ģenerēšana nav pieejama bez konfigurēta IMAGE_MODEL.", ) from exc