# Hugging Face Spaces file view of app.py: last commit "Update app.py"
# (34ead6e, verified) by dedlepexa; file size 4.64 kB.
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse, FileResponse
from diffusers import StableDiffusionPipeline
import torch
import threading
import time
from collections import OrderedDict
import os
from PIL import Image # 🔥 ДОБАВЛЕНО
app = FastAPI()

# =========================
# ⚡ CPU SETTINGS
# =========================
# Cap torch's intra-op thread pool at two threads.
torch.set_num_threads(2)

# =========================
# 🔥 PIPELINE
# =========================
# Stable Diffusion checkpoint from the Hugging Face Hub, loaded in float32
# with the safety checker disabled, placed on the CPU, with attention slicing
# enabled (diffusers documents this as a memory-saving option).
model_name = "Lykon/dreamshaper-7"
pipe = StableDiffusionPipeline.from_pretrained(
    model_name,
    safety_checker=None,
    torch_dtype=torch.float32,
).to("cpu")
pipe.enable_attention_slicing()

# =========================
# 📦 IN-MEMORY STATE
# =========================
# message -> {"status": "pending" | "done", "reply": str}. Insertion-ordered so
# the oldest message can be dropped with popitem(last=False).
db = OrderedDict()
# FIFO list of (message, mode) jobs waiting for the background worker.
queue = []
# message -> progress percentage while its job is running.
progress_db = {}

MAX_HISTORY = 40  # maximum number of messages kept in db
# NOTE(review): background worker count setting; check it is actually honored
# where the worker thread(s) are started.
NUM_WORKERS = 1

# Directory that receives generated images and their tiles.
IMG_DIR = "image"
os.makedirs(IMG_DIR, exist_ok=True)
# =========================
# ✂️ 12 SPLIT FUNCTION (NEW)
# =========================
def split_image_into_12(img_path: str, cols: int = 4, rows: int = 3):
    """Cut the image at *img_path* into a ``cols`` x ``rows`` grid of tiles.

    Tiles are saved next to the source image as ``<stem>_1.png`` ...
    ``<stem>_N.png``, where ``<stem>`` is *img_path* without its extension.
    They are numbered row by row (left to right, top to bottom). With the
    defaults (4 columns x 3 rows) this writes 12 tiles.

    Every tile is ``width // cols`` x ``height // rows`` pixels. Leftover
    pixels on the right and bottom edges are discarded.

    Returns:
        The list of written tile paths, in numbering order.

    Raises:
        ValueError: if ``cols`` or ``rows`` is < 1, or the image is too small
            to give every tile at least one pixel in each dimension.
    """
    if cols < 1 or rows < 1:
        raise ValueError(f"grid must be at least 1x1, got {cols}x{rows}")

    # splitext only strips the final extension. The old str.replace(".png", "")
    # also removed ".png" appearing anywhere else in the path.
    stem = os.path.splitext(img_path)[0]
    tile_paths = []

    # The context manager closes the underlying file once tiling is finished.
    with Image.open(img_path) as img:
        width, height = img.size
        tile_w = width // cols
        tile_h = height // rows
        if tile_w == 0 or tile_h == 0:
            # Zero-sized crops cannot be saved; fail before writing partial output.
            raise ValueError(
                f"image {width}x{height} is too small for a {cols}x{rows} grid"
            )

        for index in range(rows * cols):
            row, col = divmod(index, cols)  # row-major order: 1..N
            left = col * tile_w
            top = row * tile_h
            tile = img.crop((left, top, left + tile_w, top + tile_h))
            out_path = f"{stem}_{index + 1}.png"
            tile.save(out_path)
            tile_paths.append(out_path)

    return tile_paths
# =========================
# 🚀 GENERATION ENGINE
# =========================
def generate_ai_stream(message: str, mode="fast"):
    """Run one text-to-image job for *message* and record the outcome in db.

    Args:
        message: the prompt. It is also the key into ``db`` and ``progress_db``.
        mode: ``"fast"`` uses 2 inference steps with guidance scale 1.5.
            Any other value uses 6 steps with guidance scale 3.0.

    While the job runs, ``progress_db[message]`` holds a percentage for /get.
    That value is simulated: it advances during a short artificial delay
    (0.12 s per step) before the pipeline call. It then jumps to 100 once the
    image has been saved and tiled, so it does not track real denoising.

    The 256x256 result is saved as ``IMG_DIR/img_<epoch-ms>.png`` and then split
    into tiles by ``split_image_into_12``.

    Returns:
        The reply string, which is also stored in ``db[message]["reply"]``.
        On success it is ``"<filename> | <mode> | <seconds>s"``. If any step
        raised, it is ``"error: <exception>"``. The db entry is updated only
        when the message is still present, because the request handlers may
        have evicted it while the job ran.
    """
    try:
        start = time.perf_counter()  # monotonic clock for measuring the duration

        if mode == "fast":
            steps, cfg = 2, 1.5
        else:
            steps, cfg = 6, 3.0

        progress_db[message] = 0
        # Simulated progress, advanced before the (blocking) pipeline call.
        for i in range(steps):
            progress_db[message] = int(i / steps * 100)
            time.sleep(0.12)

        image = pipe(
            message,
            num_inference_steps=steps,
            guidance_scale=cfg,
            height=256,
            width=256,
        ).images[0]

        filename = f"{IMG_DIR}/img_{int(time.time() * 1000)}.png"
        image.save(filename)
        split_image_into_12(filename)

        progress_db[message] = 100
        duration = round(time.perf_counter() - start, 2)
        result = f"{filename} | {mode} | {duration}s"
    except Exception as e:
        # Job boundary: report any failure to the client instead of raising.
        result = f"error: {e}"

    # The request handlers trim db (popitem) on another thread, so the entry
    # can disappear between an `in` check and indexing. The resulting KeyError
    # used to escape this function and kill the worker thread. One .get() avoids it.
    entry = db.get(message)
    if entry is not None:
        entry["reply"] = result  # set reply before status so /get never sees "done" with an empty reply
        entry["status"] = "done"
    progress_db.pop(message, None)
    return result
# =========================
# 🔄 WORKER
# =========================
def worker():
    """Consume generation jobs from ``queue`` forever (runs in a daemon thread).

    Each job is a ``(message, mode)`` tuple appended by the /fast and /quality
    handlers. A job whose db entry is already ``"done"`` is skipped. Every
    other job is passed to ``generate_ai_stream``. When the queue is empty,
    the loop sleeps 30 ms before polling again.
    """
    import logging

    log = logging.getLogger(__name__)
    while True:
        # EAFP pop: with several workers, another thread can empty the queue
        # between an emptiness check and pop(0), so catch IndexError instead.
        try:
            message, mode = queue.pop(0)
        except IndexError:
            time.sleep(0.03)
            continue

        # .get() instead of `in` + index: the handlers may evict the entry
        # from db concurrently (MAX_HISTORY trimming).
        entry = db.get(message)
        if entry is not None and entry["status"] == "done":
            continue

        try:
            generate_ai_stream(message, mode)
        except Exception:
            # generate_ai_stream already turns generation failures into an
            # "error: ..." reply. This guard only keeps an unexpected failure
            # from silently killing the thread and stalling every later job.
            log.exception("worker: job for %r failed", message)


# Start the configured number of background consumers.
# NOTE(review): with NUM_WORKERS > 1 all threads share the single `pipe`
# object; confirm the pipeline is safe to call concurrently before raising it.
for _ in range(NUM_WORKERS):
    threading.Thread(target=worker, daemon=True).start()
# =========================
# 🌐 API
# =========================
@app.get("/")
async def root():
    """Liveness check: reply with a fixed plain-text banner."""
    banner = "⚡ CPU LCM Image Generator Running"
    return PlainTextResponse(banner)
def _enqueue(message: str, mode: str):
    """Record *message* as pending in db and queue a ``(message, mode)`` job.

    Shared by the /fast and /quality handlers. A message already in db is
    left alone, so asking for it again does not queue a duplicate job. The
    exception is a message whose last attempt finished with an
    ``"error: ..."`` reply: it is reset to pending and queued again, so a
    client retry actually does something.

    db holds at most MAX_HISTORY entries. After a new entry is added, the
    oldest one is evicted, whether or not its job has finished.
    """
    entry = db.get(message)
    failed_before = (
        entry is not None
        and entry["status"] == "done"
        and entry["reply"].startswith("error:")
    )
    if entry is not None and not failed_before:
        return

    # Remove a failed entry first so the retry is re-inserted as the newest one.
    db.pop(message, None)
    db[message] = {"status": "pending", "reply": ""}
    queue.append((message, mode))
    if len(db) > MAX_HISTORY:
        db.popitem(last=False)  # evict the oldest message


@app.get("/fast")
async def fast(message: str):
    """Queue *message* for generation in "fast" mode; poll /get for the result."""
    _enqueue(message, "fast")
    return PlainTextResponse("accepted")


@app.get("/quality")
async def quality(message: str):
    """Queue *message* for generation in "quality" mode; poll /get for the result."""
    _enqueue(message, "quality")
    return PlainTextResponse("accepted")
@app.get("/get")
async def get(message: str):
    """Report the state of the job for *message* as plain text.

    The reply is one of:
    - ``"not found"`` if the message is unknown (or was evicted);
    - ``"generating... <n>%"`` while the job is pending;
    - the stored reply once the job is done.
    """
    entry = db.get(message)
    if entry is None:
        return PlainTextResponse("not found")

    if entry["status"] != "pending":
        return PlainTextResponse(entry["reply"])

    percent = progress_db.get(message, 0)
    return PlainTextResponse(f"generating... {percent}%")
@app.get("/image")
async def get_image(path: str):
    """Serve a generated image or tile that lives under IMG_DIR.

    Args:
        path: a file path such as the filename part of a /get reply
            (``image/img_<ms>.png``) or one of its ``_<n>.png`` tiles.

    *path* comes straight from the query string. Only a regular file whose
    resolved location is inside IMG_DIR is served. Anything else gets the same
    ``"file not found"`` reply, which rules out reading arbitrary server
    files. That includes missing files, directories, and paths that escape
    IMG_DIR through ``..``, an absolute path, or a symlink.
    """
    try:
        img_root = os.path.realpath(IMG_DIR)
        target = os.path.realpath(path)
        # commonpath raises ValueError for paths on different drives (Windows);
        # realpath can raise it for embedded NUL bytes. Treat both as "not found".
        inside = os.path.commonpath([img_root, target]) == img_root
    except ValueError:
        inside = False

    if not inside or not os.path.isfile(target):
        return PlainTextResponse("file not found")
    return FileResponse(target)
# =========================
# 🚀 RUN
# =========================
if __name__ == "__main__":
    # Start the ASGI server only when this file is executed directly.
    # uvicorn is imported lazily because only this path needs it.
    from uvicorn import run

    run(app, host="0.0.0.0", port=7860)