| """Attēlu ģenerēšana ar Stable Diffusion.""" |
|
|
from __future__ import annotations

import base64
import io
import logging
import threading

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from maris_core.utils.env import get_hf_model
|
|
# Route table for the endpoints defined in this module; presumably included by
# the application with some prefix — confirm in the app factory.
router = APIRouter()
# Module-scoped logger so logging config can target this module by name.
logger = logging.getLogger(__name__)
|
|
|
|
| class ImageRequest(BaseModel): |
| prompt: str |
| width: int = 1024 |
| height: int = 1024 |
| steps: int = 30 |
| guidance_scale: float = 7.5 |
|
|
|
|
class ImageResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # Generated image; generate_image fills this with a
    # "data:image/png;base64,..." data URL.
    image_url: str
    # Echo of the prompt the image was generated from.
    prompt: str
|
|
|
|
# Serialises pipeline loading and inference. A diffusers pipeline mutates its
# scheduler state during a call, so one shared instance must not run in several
# threads at once. Running renders one at a time also bounds GPU memory use.
_PIPELINE_LOCK = threading.Lock()
# Holds at most one loaded pipeline, keyed by (model_id, device).
_PIPELINE_CACHE: dict = {}


def _get_pipeline(model_id: str, device: str):
    """Return a cached Stable Diffusion pipeline, loading it on first use.

    Loading weights is the most expensive step, so it happens once per
    (model, device) rather than on every request. Callers must hold
    ``_PIPELINE_LOCK``.
    """
    key = (model_id, device)
    pipe = _PIPELINE_CACHE.get(key)
    if pipe is None:
        import torch
        from diffusers import StableDiffusionPipeline

        # Half precision is only dependable on CUDA; many CPU kernels have no
        # float16 implementation, so the CPU fallback needs float32.
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
        pipe = pipe.to(device)
        # Drop any previously loaded model so only one stays resident.
        _PIPELINE_CACHE.clear()
        _PIPELINE_CACHE[key] = pipe
    return pipe


def _render_png_base64(req: ImageRequest, model_id: str) -> str:
    """Generate one image for *req* and return it as base64-encoded PNG text.

    This is blocking (model load plus diffusion) and must run off the event loop.
    """
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    with _PIPELINE_LOCK:
        pipe = _get_pipeline(model_id, device)
        image = pipe(
            req.prompt,
            width=req.width,
            height=req.height,
            num_inference_steps=req.steps,
            guidance_scale=req.guidance_scale,
        ).images[0]

    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


@router.post("/generate", response_model=ImageResponse)
async def generate_image(req: ImageRequest) -> ImageResponse:
    """Generate an image from a text description.

    Args:
        req: Prompt and sampling parameters.

    Returns:
        The image as a PNG data URL, together with the prompt.

    Raises:
        HTTPException: 503 if resolving, loading or running the model fails.
    """
    from fastapi.concurrency import run_in_threadpool
    from maris_core.utils.hf_integration import HFIntegration

    hf = HFIntegration()

    try:
        model_id = get_hf_model("IMAGE_MODEL")
        # Loading and diffusion are CPU/GPU-bound and can take a long time.
        # Run them in a worker thread so the event loop keeps serving requests.
        b64 = await run_in_threadpool(_render_png_base64, req, model_id)
    except Exception as exc:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("Attēla ģenerēšanas kļūda: %s", exc)
        raise HTTPException(
            status_code=503,
            detail="Maris AI attēlu ģenerēšana nav pieejama bez konfigurēta IMAGE_MODEL.",
        ) from exc

    # Recording the generation is best-effort. A failure here is logged rather
    # than discarding an image that has already been produced.
    try:
        await hf.save_generation("image", req.prompt, {"image_b64": b64[:100] + "..."})
    except Exception:
        logger.exception("Neizdevās saglabāt attēla ģenerēšanas ierakstu")

    return ImageResponse(image_url=f"data:image/png;base64,{b64}", prompt=req.prompt)
|
|