ebraam1 commited on
Commit
2c6d712
·
verified ·
1 Parent(s): e116903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -5
app.py CHANGED
@@ -4,6 +4,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
  import torch
5
  from PIL import Image
6
  import io
 
7
  from fastapi.responses import StreamingResponse
8
 
9
  app = FastAPI()
@@ -11,25 +12,34 @@ app = FastAPI()
11
  MODEL_PATH = "Interior.safetensors"
12
  LORA_PATH = "Interior_lora.safetensors"
13
 
 
 
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  dtype = torch.float16 if device == "cuda" else torch.float32
16
 
 
17
 
18
  print("Loading base model...")
19
 
 
 
 
20
  txt2img = StableDiffusionPipeline.from_single_file(
21
  MODEL_PATH,
22
  torch_dtype=dtype,
23
  safety_checker=None
24
  ).to(device)
25
 
 
 
 
26
  img2img = StableDiffusionImg2ImgPipeline.from_single_file(
27
  MODEL_PATH,
28
  torch_dtype=dtype,
29
  safety_checker=None
30
  ).to(device)
31
 
32
-
33
  print("Loading LoRA...")
34
 
35
  txt2img.load_lora_weights(LORA_PATH)
@@ -38,13 +48,28 @@ img2img.load_lora_weights(LORA_PATH)
38
  txt2img.fuse_lora(lora_scale=0.8)
39
  img2img.fuse_lora(lora_scale=0.8)
40
 
 
 
 
 
 
 
 
 
 
41
  print("LoRA loaded 🔥")
42
 
43
 
 
 
 
44
  class Prompt(BaseModel):
45
  prompt: str
46
 
47
 
 
 
 
48
  def to_bytes(img):
49
  buf = io.BytesIO()
50
  img.save(buf, format="PNG")
@@ -52,19 +77,49 @@ def to_bytes(img):
52
  return buf
53
 
54
 
 
 
 
55
  @app.get("/")
56
  def home():
57
  return {"status": "API is running 🚀"}
58
 
59
 
 
 
 
60
  @app.post("/txt2img")
61
  def generate(data: Prompt):
62
- image = txt2img(data.prompt).images[0]
 
 
 
 
 
 
 
 
63
  return StreamingResponse(to_bytes(image), media_type="image/png")
64
 
65
 
 
 
 
66
  @app.post("/img2img")
67
- async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
68
- img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512, 512))
69
- image = img2img(prompt=prompt, image=img).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return StreamingResponse(to_bytes(image), media_type="image/png")
 
4
  import torch
5
  from PIL import Image
6
  import io
7
+ import os
8
  from fastapi.responses import StreamingResponse
9
 
10
  app = FastAPI()
 
12
  MODEL_PATH = "Interior.safetensors"
13
  LORA_PATH = "Interior_lora.safetensors"
14
 
15
+ # ========================
16
+ # CPU Optimization
17
+ # ========================
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  dtype = torch.float16 if device == "cuda" else torch.float32
20
 
21
+ torch.set_num_threads(os.cpu_count()) # تسريع CPU execution
22
 
23
  print("Loading base model...")
24
 
25
+ # ========================
26
+ # TXT2IMG PIPELINE
27
+ # ========================
28
  txt2img = StableDiffusionPipeline.from_single_file(
29
  MODEL_PATH,
30
  torch_dtype=dtype,
31
  safety_checker=None
32
  ).to(device)
33
 
34
+ # ========================
35
+ # IMG2IMG PIPELINE
36
+ # ========================
37
  img2img = StableDiffusionImg2ImgPipeline.from_single_file(
38
  MODEL_PATH,
39
  torch_dtype=dtype,
40
  safety_checker=None
41
  ).to(device)
42
 
 
43
  print("Loading LoRA...")
44
 
45
  txt2img.load_lora_weights(LORA_PATH)
 
48
  txt2img.fuse_lora(lora_scale=0.8)
49
  img2img.fuse_lora(lora_scale=0.8)
50
 
51
+ # ========================
52
+ # SPEED BOOSTS
53
+ # ========================
54
+ txt2img.enable_attention_slicing()
55
+ txt2img.enable_vae_slicing()
56
+
57
+ img2img.enable_attention_slicing()
58
+ img2img.enable_vae_slicing()
59
+
60
  print("LoRA loaded 🔥")
61
 
62
 
63
+ # ========================
64
+ # REQUEST MODEL
65
+ # ========================
66
  class Prompt(BaseModel):
67
  prompt: str
68
 
69
 
70
+ # ========================
71
+ # IMAGE UTILS
72
+ # ========================
73
  def to_bytes(img):
74
  buf = io.BytesIO()
75
  img.save(buf, format="PNG")
 
77
  return buf
78
 
79
 
80
+ # ========================
81
+ # HEALTH CHECK
82
+ # ========================
83
  @app.get("/")
84
  def home():
85
  return {"status": "API is running 🚀"}
86
 
87
 
88
+ # ========================
89
+ # TXT2IMG ENDPOINT (FAST MODE)
90
+ # ========================
91
  @app.post("/txt2img")
92
  def generate(data: Prompt):
93
+
94
+ image = txt2img(
95
+ data.prompt,
96
+ num_inference_steps=10, # ⚡ أسرع حاجة
97
+ guidance_scale=5,
98
+ height=384,
99
+ width=384
100
+ ).images[0]
101
+
102
  return StreamingResponse(to_bytes(image), media_type="image/png")
103
 
104
 
105
+ # ========================
106
+ # IMG2IMG ENDPOINT (FAST MODE)
107
+ # ========================
108
  @app.post("/img2img")
109
+ async def img2img_api(
110
+ file: UploadFile = File(...),
111
+ prompt: str = ""
112
+ ):
113
+
114
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
115
+ img = img.resize((384, 384)) # ⚡ تسريع مهم جدًا
116
+
117
+ image = img2img(
118
+ prompt=prompt,
119
+ image=img,
120
+ strength=0.6,
121
+ num_inference_steps=10,
122
+ guidance_scale=5
123
+ ).images[0]
124
+
125
  return StreamingResponse(to_bytes(image), media_type="image/png")