ebraam1 commited on
Commit
b5e3251
·
verified ·
1 Parent(s): 846e0b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, UploadFile, File, Query
2
- from diffusers import StableDiffusionPipeline
3
  from huggingface_hub import hf_hub_download
4
  from fastapi.responses import StreamingResponse
5
  from PIL import Image
@@ -24,19 +24,22 @@ pipe = StableDiffusionPipeline.from_single_file(
24
  requires_safety_checker=False
25
  ).to("cpu")
26
 
 
 
 
27
  # ---------- تحميل الـ LoRA ودمجه ----------
28
  print("Downloading LoRA...")
29
  lora_path = hf_hub_download(repo_id=MODEL_REPO, filename=LORA_FILE)
30
 
31
  print("Loading LoRA...")
32
  pipe.load_lora_weights(lora_path)
33
- pipe.fuse_lora(lora_scale=1.0) # القوة الكاملة للـ LoRA
34
 
35
- # تحسين الأداء (بدون تغيير الجدول الزمني)
36
  pipe.enable_attention_slicing()
37
  pipe.enable_vae_slicing()
38
 
39
- print("Model ready 🔥 (Base + LoRA fused)")
40
 
41
  # ---------- دوال مساعدة ----------
42
  def to_bytes(img: Image.Image) -> io.BytesIO:
@@ -45,14 +48,14 @@ def to_bytes(img: Image.Image) -> io.BytesIO:
45
  buf.seek(0)
46
  return buf
47
 
48
- # ---------- Text‑to‑Image (GET بسيط بباراميتر تيكست) ----------
49
  @app.get("/txt2img")
50
  def txt2img(
51
  prompt: str = Query(..., description="الوصف اللي عايز تولده"),
52
- steps: int = Query(8, ge=1, le=20),
53
  guidance: float = Query(7.5, ge=1.0, le=20.0),
54
- height: int = Query(512, ge=256, le=768),
55
- width: int = Query(512, ge=256, le=768)
56
  ):
57
  image = pipe(
58
  prompt=prompt,
@@ -63,16 +66,16 @@ def txt2img(
63
  ).images[0]
64
  return StreamingResponse(to_bytes(image), media_type="image/png")
65
 
66
- # ---------- Image‑to‑Image (POST مع رفع صورة + باراميترات) ----------
67
  @app.post("/img2img")
68
  async def img2img(
69
  file: UploadFile = File(...),
70
  prompt: str = Query(""),
71
- steps: int = Query(8, ge=1, le=20),
72
  guidance: float = Query(7.5, ge=1.0, le=20.0),
73
  strength: float = Query(0.6, ge=0.0, le=1.0),
74
- height: int = Query(512, ge=256, le=768),
75
- width: int = Query(512, ge=256, le=768)
76
  ):
77
  img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((width, height))
78
  image = pipe(
 
1
  from fastapi import FastAPI, UploadFile, File, Query
2
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
3
  from huggingface_hub import hf_hub_download
4
  from fastapi.responses import StreamingResponse
5
  from PIL import Image
 
24
  requires_safety_checker=False
25
  ).to("cpu")
26
 
27
+ # ---------- استخدام Scheduler أسرع (DPM-Solver) ----------
28
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
29
+
30
  # ---------- تحميل الـ LoRA ودمجه ----------
31
  print("Downloading LoRA...")
32
  lora_path = hf_hub_download(repo_id=MODEL_REPO, filename=LORA_FILE)
33
 
34
  print("Loading LoRA...")
35
  pipe.load_lora_weights(lora_path)
36
+ pipe.fuse_lora(lora_scale=1.0) # القوة الكاملة
37
 
38
+ # تحسينات الذاكرة
39
  pipe.enable_attention_slicing()
40
  pipe.enable_vae_slicing()
41
 
42
+ print("Model ready 🔥 (Base + LoRA fused with DPM-Solver)")
43
 
44
  # ---------- دوال مساعدة ----------
45
  def to_bytes(img: Image.Image) -> io.BytesIO:
 
48
  buf.seek(0)
49
  return buf
50
 
51
+ # ---------- Text‑to‑Image (GET بسيط) ----------
52
  @app.get("/txt2img")
53
  def txt2img(
54
  prompt: str = Query(..., description="الوصف اللي عايز تولده"),
55
+ steps: int = Query(6, ge=1, le=20), # افتراضي 6 (أسرع)
56
  guidance: float = Query(7.5, ge=1.0, le=20.0),
57
+ height: int = Query(384, ge=256, le=768), # افتراضي 384
58
+ width: int = Query(384, ge=256, le=768) # افتراضي 384
59
  ):
60
  image = pipe(
61
  prompt=prompt,
 
66
  ).images[0]
67
  return StreamingResponse(to_bytes(image), media_type="image/png")
68
 
69
+ # ---------- Image‑to‑Image (POST) ----------
70
  @app.post("/img2img")
71
  async def img2img(
72
  file: UploadFile = File(...),
73
  prompt: str = Query(""),
74
+ steps: int = Query(6, ge=1, le=20),
75
  guidance: float = Query(7.5, ge=1.0, le=20.0),
76
  strength: float = Query(0.6, ge=0.0, le=1.0),
77
+ height: int = Query(384, ge=256, le=768),
78
+ width: int = Query(384, ge=256, le=768)
79
  ):
80
  img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((width, height))
81
  image = pipe(