File size: 2,055 Bytes
52a3b75
 
8512e9b
3d53cf9
 
52a3b75
3d53cf9
52a3b75
 
 
 
3d53cf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a3b75
4a17d02
e116903
52a3b75
 
8512e9b
 
2c6d712
3d53cf9
 
 
8512e9b
 
52a3b75
8512e9b
 
 
3d53cf9
52a3b75
 
 
 
 
 
 
 
 
 
 
3d53cf9
8512e9b
3d53cf9
e116903
 
2c6d712
8512e9b
2c6d712
3d53cf9
8512e9b
 
 
2c6d712
 
52a3b75
 
 
3d53cf9
 
 
52a3b75
4a17d02
2c6d712
8512e9b
3d53cf9
2c6d712
8512e9b
2c6d712
 
8512e9b
 
 
2c6d712
 
8512e9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download
from fastapi.responses import StreamingResponse
from PIL import Image
import torch
import io

app = FastAPI()

# =========================
# Download the LoRA weights from the Hugging Face Hub
# =========================
LORA_PATH = hf_hub_download(
    repo_id="ebraam1/interior-sd-models",
    filename="Interior_lora.safetensors",
)


def _build_pipeline(lora_file):
    """Load the base Stable Diffusion model on CPU and bake the LoRA into it.

    Args:
        lora_file: Local path of the LoRA weights file to load.

    Returns:
        The configured ``StableDiffusionPipeline`` instance.
    """
    print("Loading base model...")

    # FIX (kept from the original author): load a pretrained repo instead of
    # a single checkpoint file.
    base = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        safety_checker=None,
    )
    base = base.to("cpu")

    print("Loading LoRA...")

    base.load_lora_weights(lora_file)
    base.fuse_lora(lora_scale=0.8)

    # Originally labelled "speed optimizations".
    # NOTE(review): diffusers documents attention/VAE slicing as memory
    # savers rather than speed-ups — confirm they actually help on CPU.
    base.enable_attention_slicing()
    base.enable_vae_slicing()

    print("Model ready 🔥")
    return base


# Built once at import time and shared by every request handler.
pipe = _build_pipeline(LORA_PATH)


# =========================
# Request body schema for the /txt2img endpoint.
class Prompt(BaseModel):
    # Text prompt that is passed to the diffusion pipeline.
    prompt: str


def to_bytes(img):
    """Encode *img* as PNG into an in-memory stream.

    Args:
        img: Any object exposing ``save(fp, format=...)``; the callers in
            this file pass PIL images.

    Returns:
        An ``io.BytesIO`` holding the PNG data, rewound to its first byte so
        it can be read or streamed immediately.
    """
    png_stream = io.BytesIO()
    img.save(png_stream, format="PNG")
    # Rewind: save() leaves the cursor at the end of the written data.
    png_stream.seek(0, io.SEEK_SET)
    return png_stream


# =========================
# TXT2IMG
# =========================
@app.post("/txt2img")
def generate(data: Prompt):
    """Generate one 256x256 image from a text prompt and return it as PNG.

    Args:
        data: Request body carrying the ``prompt`` string.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    # Small resolution and few steps keep generation cheap; the module-level
    # pipeline runs on CPU.
    sampling = {
        "num_inference_steps": 6,
        "guidance_scale": 5,
        "height": 256,
        "width": 256,
    }
    result = pipe(data.prompt, **sampling)
    png_stream = to_bytes(result.images[0])
    return StreamingResponse(png_stream, media_type="image/png")


# =========================
# IMG2IMG (correct way)
# =========================
# Lazily-built image-to-image pipeline; see _get_img2img_pipe().
_img2img_pipe = None


def _get_img2img_pipe():
    """Return the image-to-image pipeline, building it on first use.

    The module-level ``pipe`` is a text-to-image ``StableDiffusionPipeline``;
    per the diffusers API its ``__call__`` takes no ``image``/``strength``
    arguments, so it cannot serve img2img requests. The img2img pipeline is
    constructed from ``pipe.components``, i.e. it reuses the already-loaded
    model objects (with the fused LoRA) instead of loading weights again.
    """
    global _img2img_pipe
    if _img2img_pipe is None:
        # Function-scope import keeps this fix self-contained in the block.
        from diffusers import StableDiffusionImg2ImgPipeline

        _img2img_pipe = StableDiffusionImg2ImgPipeline(
            **pipe.components,
            # The base pipeline was loaded with safety_checker=None.
            requires_safety_checker=False,
        )
    return _img2img_pipe


@app.post("/img2img")
async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
    """Re-draw an uploaded image guided by a text prompt and return a PNG.

    Args:
        file: Uploaded image; it is decoded, converted to RGB and resized to
            256x256 before being fed to the pipeline.
        prompt: Text prompt steering the result (empty string by default).

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    from fastapi import HTTPException

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as exc:
        # PIL's UnidentifiedImageError is an OSError subclass; report bad
        # uploads as a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from exc
    img = img.resize((256, 256))

    # NOTE(review): inference is blocking work inside an async handler, so
    # the event loop is stalled while an image is generated — confirm this
    # is acceptable for the expected load.
    image = _get_img2img_pipe()(
        prompt=prompt,
        image=img,
        strength=0.6,
        num_inference_steps=6,
        guidance_scale=5,
    ).images[0]

    return StreamingResponse(to_bytes(image), media_type="image/png")