MarisUK's picture
Maris AI model sync
f440f03 verified
"""Attēlu ģenerēšana ar Stable Diffusion."""
from __future__ import annotations

import base64
import functools
import io
import logging
import threading

from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from maris_core.utils.env import get_hf_model
# Router exposing the image-generation endpoints; mounted by the application.
router = APIRouter()
# Module-scoped logger so records carry this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
class ImageRequest(BaseModel):
    """Request body for ``POST /generate``.

    The size and step constraints reject inputs that the Stable Diffusion
    pipeline cannot use (image sides must be multiples of 8, and at least one
    denoising step is needed). Validating them here yields a clear 422
    response instead of a generic 503 from the endpoint. Defaults are
    unchanged.
    """

    # Text description of the image to generate.
    prompt: str
    # Output width in pixels; positive and divisible by 8.
    width: int = Field(1024, gt=0, multiple_of=8)
    # Output height in pixels; positive and divisible by 8.
    height: int = Field(1024, gt=0, multiple_of=8)
    # Number of denoising steps passed as ``num_inference_steps``.
    steps: int = Field(30, ge=1)
    # Classifier-free guidance scale passed straight to the pipeline.
    guidance_scale: float = 7.5
class ImageResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # ``data:image/png;base64,...`` URL embedding the full generated PNG.
    image_url: str
    # The prompt the image was generated from, echoed back to the caller.
    prompt: str
# Serializes pipeline loading and inference. The cached pipeline object is
# shared by every request, and generation now runs in worker threads, so
# overlapping calls must not drive the same pipeline concurrently. This also
# prevents two threads from loading the weights twice on a cold cache.
_PIPELINE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_pipeline(model_id: str):
    """Load the Stable Diffusion pipeline for *model_id* once and reuse it.

    ``maxsize=1`` keeps at most one set of model weights resident; if
    ``IMAGE_MODEL`` changes, the previous pipeline is evicted. A failed load
    raises and is not cached, so the next request retries.
    """
    import torch  # type: ignore
    from diffusers import StableDiffusionPipeline  # type: ignore

    use_cuda = torch.cuda.is_available()
    # float16 is only safe on GPU; many CPU kernels have no Half
    # implementation, so fall back to float32 when running on CPU.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    return pipe.to("cuda" if use_cuda else "cpu")


def _generate_png_b64(model_id: str, req: ImageRequest) -> str:
    """Blocking helper: generate one image for *req* and return it as base64 PNG text."""
    with _PIPELINE_LOCK:
        pipe = _load_pipeline(model_id)
        image = pipe(
            req.prompt,
            width=req.width,
            height=req.height,
            num_inference_steps=req.steps,
            guidance_scale=req.guidance_scale,
        ).images[0]
    # Encoding happens outside the lock: it no longer touches the pipeline.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()


@router.post("/generate", response_model=ImageResponse)
async def generate_image(req: ImageRequest) -> ImageResponse:
    """Generate an image from a text description.

    The model id is read from the ``IMAGE_MODEL`` setting. The pipeline is
    loaded once and reused. Inference runs in a worker thread so the event
    loop is not blocked.

    Args:
        req: Prompt and generation parameters.

    Returns:
        The generated image as a PNG data URL, plus the original prompt.

    Raises:
        HTTPException: 503 if the model cannot be resolved, loaded or run.
    """
    from maris_core.utils.hf_integration import HFIntegration

    hf = HFIntegration()
    try:
        model_id = get_hf_model("IMAGE_MODEL")
        # Model loading, inference and PNG encoding are all blocking and
        # compute-heavy; keep them off the event loop.
        b64 = await run_in_threadpool(_generate_png_b64, model_id, req)
    except Exception as exc:  # noqa: BLE001
        # logger.exception keeps the traceback, unlike logger.error.
        logger.exception("Attēla ģenerēšanas kļūda: %s", exc)
        raise HTTPException(
            status_code=503,
            detail="Maris AI attēlu ģenerēšana nav pieejama bez konfigurēta IMAGE_MODEL.",
        ) from exc

    # Record the generation in history (only a truncated base64 preview is
    # stored). This is best-effort: a history failure must not discard an
    # image that was already generated successfully.
    try:
        await hf.save_generation("image", req.prompt, {"image_b64": b64[:100] + "..."})
    except Exception:  # noqa: BLE001
        logger.warning("Neizdevās saglabāt attēla ģenerēšanas vēsturi", exc_info=True)

    return ImageResponse(image_url=f"data:image/png;base64,{b64}", prompt=req.prompt)