ebraam1 commited on
Commit
2b3a7bb
·
verified ·
1 Parent(s): 19d3cb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -24
app.py CHANGED
@@ -1,5 +1,4 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from pydantic import BaseModel
3
  from diffusers import StableDiffusionPipeline
4
  from huggingface_hub import hf_hub_download
5
  from fastapi.responses import StreamingResponse
@@ -10,52 +9,77 @@ import io
10
  app = FastAPI()
11
 
12
  MODEL_REPO = "ebraam1/interior-sd-models"
13
- MODEL_FILE = "Interior.safetensors"
 
14
 
15
- print("Downloading model file...")
16
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
 
17
 
18
- print("Loading model on CPU...")
19
  pipe = StableDiffusionPipeline.from_single_file(
20
- model_path,
21
  torch_dtype=torch.float32,
22
  safety_checker=None,
23
  requires_safety_checker=False
24
  ).to("cpu")
25
 
 
 
 
 
 
 
 
 
 
26
  pipe.enable_attention_slicing()
27
  pipe.enable_vae_slicing()
28
 
29
- print("Model ready 🔥")
30
 
31
- class Prompt(BaseModel):
32
- prompt: str
33
-
34
- def to_bytes(img: Image.Image):
35
  buf = io.BytesIO()
36
  img.save(buf, format="PNG")
37
  buf.seek(0)
38
  return buf
39
 
40
- @app.post("/txt2img")
41
- def generate(data: Prompt):
 
 
 
 
 
 
 
42
  image = pipe(
43
- data.prompt,
44
- num_inference_steps=8,
45
- guidance_scale=7.5,
46
- height=512,
47
- width=512
48
  ).images[0]
49
  return StreamingResponse(to_bytes(image), media_type="image/png")
50
 
 
51
  @app.post("/img2img")
52
- async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
53
- img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512, 512))
 
 
 
 
 
 
 
 
54
  image = pipe(
55
  prompt=prompt,
56
  image=img,
57
- strength=0.6,
58
- num_inference_steps=8,
59
- guidance_scale=7.5
60
  ).images[0]
61
  return StreamingResponse(to_bytes(image), media_type="image/png")
 
1
+ from fastapi import FastAPI, UploadFile, File, Query
 
2
  from diffusers import StableDiffusionPipeline
3
  from huggingface_hub import hf_hub_download
4
  from fastapi.responses import StreamingResponse
 
9
  app = FastAPI()
10
 
11
  MODEL_REPO = "ebraam1/interior-sd-models"
12
+ BASE_MODEL = "Interior.safetensors"
13
+ LORA_FILE = "Interior_lora.safetensors"
14
 
15
+ # ---------- تحميل الموديل الأساسي ----------
16
+ print("Downloading base model...")
17
+ base_path = hf_hub_download(repo_id=MODEL_REPO, filename=BASE_MODEL)
18
 
19
+ print("Loading base model on CPU...")
20
  pipe = StableDiffusionPipeline.from_single_file(
21
+ base_path,
22
  torch_dtype=torch.float32,
23
  safety_checker=None,
24
  requires_safety_checker=False
25
  ).to("cpu")
26
 
27
+ # ---------- تحميل الـ LoRA ودمجه ----------
28
+ print("Downloading LoRA...")
29
+ lora_path = hf_hub_download(repo_id=MODEL_REPO, filename=LORA_FILE)
30
+
31
+ print("Loading LoRA...")
32
+ pipe.load_lora_weights(lora_path)
33
+ pipe.fuse_lora(lora_scale=1.0) # القوة الكاملة للـ LoRA
34
+
35
+ # تحسين الأداء (بدون تغيير الجدول الزمني)
36
  pipe.enable_attention_slicing()
37
  pipe.enable_vae_slicing()
38
 
39
+ print("Model ready 🔥 (Base + LoRA fused)")
40
 
41
+ # ---------- دوال مساعدة ----------
42
+ def to_bytes(img: Image.Image) -> io.BytesIO:
 
 
43
  buf = io.BytesIO()
44
  img.save(buf, format="PNG")
45
  buf.seek(0)
46
  return buf
47
 
48
+ # ---------- Text‑to‑Image (GET بسيط بباراميتر تيكست) ----------
49
+ @app.get("/txt2img")
50
+ def txt2img(
51
+ prompt: str = Query(..., description="الوصف اللي عايز تولده"),
52
+ steps: int = Query(8, ge=1, le=20),
53
+ guidance: float = Query(7.5, ge=1.0, le=20.0),
54
+ height: int = Query(512, ge=256, le=768),
55
+ width: int = Query(512, ge=256, le=768)
56
+ ):
57
  image = pipe(
58
+ prompt=prompt,
59
+ num_inference_steps=steps,
60
+ guidance_scale=guidance,
61
+ height=height,
62
+ width=width
63
  ).images[0]
64
  return StreamingResponse(to_bytes(image), media_type="image/png")
65
 
66
+ # ---------- Image‑to‑Image (POST مع رفع صورة + باراميترات) ----------
67
  @app.post("/img2img")
68
+ async def img2img(
69
+ file: UploadFile = File(...),
70
+ prompt: str = Query(""),
71
+ steps: int = Query(8, ge=1, le=20),
72
+ guidance: float = Query(7.5, ge=1.0, le=20.0),
73
+ strength: float = Query(0.6, ge=0.0, le=1.0),
74
+ height: int = Query(512, ge=256, le=768),
75
+ width: int = Query(512, ge=256, le=768)
76
+ ):
77
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((width, height))
78
  image = pipe(
79
  prompt=prompt,
80
  image=img,
81
+ strength=strength,
82
+ num_inference_steps=steps,
83
+ guidance_scale=guidance
84
  ).images[0]
85
  return StreamingResponse(to_bytes(image), media_type="image/png")