ebraam1 commited on
Commit
dd92e9e
·
verified ·
1 Parent(s): 2c6d712

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -36
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
  import torch
5
  from PIL import Image
6
  import io
@@ -13,28 +13,19 @@ MODEL_PATH = "Interior.safetensors"
13
  LORA_PATH = "Interior_lora.safetensors"
14
 
15
  # ========================
16
- # CPU Optimization
17
  # ========================
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  dtype = torch.float16 if device == "cuda" else torch.float32
20
 
21
- torch.set_num_threads(os.cpu_count()) # تسريع CPU execution
22
 
23
- print("Loading base model...")
24
 
25
  # ========================
26
- # TXT2IMG PIPELINE
27
  # ========================
28
- txt2img = StableDiffusionPipeline.from_single_file(
29
- MODEL_PATH,
30
- torch_dtype=dtype,
31
- safety_checker=None
32
- ).to(device)
33
-
34
- # ========================
35
- # IMG2IMG PIPELINE
36
- # ========================
37
- img2img = StableDiffusionImg2ImgPipeline.from_single_file(
38
  MODEL_PATH,
39
  torch_dtype=dtype,
40
  safety_checker=None
@@ -42,22 +33,16 @@ img2img = StableDiffusionImg2ImgPipeline.from_single_file(
42
 
43
  print("Loading LoRA...")
44
 
45
- txt2img.load_lora_weights(LORA_PATH)
46
- img2img.load_lora_weights(LORA_PATH)
47
-
48
- txt2img.fuse_lora(lora_scale=0.8)
49
- img2img.fuse_lora(lora_scale=0.8)
50
 
51
  # ========================
52
- # SPEED BOOSTS
53
  # ========================
54
- txt2img.enable_attention_slicing()
55
- txt2img.enable_vae_slicing()
56
-
57
- img2img.enable_attention_slicing()
58
- img2img.enable_vae_slicing()
59
 
60
- print("LoRA loaded 🔥")
61
 
62
 
63
  # ========================
@@ -86,24 +71,24 @@ def home():
86
 
87
 
88
  # ========================
89
- # TXT2IMG ENDPOINT (FAST MODE)
90
  # ========================
91
  @app.post("/txt2img")
92
  def generate(data: Prompt):
93
 
94
- image = txt2img(
95
  data.prompt,
96
- num_inference_steps=10, # ⚡ أسرع حاجة
97
  guidance_scale=5,
98
- height=384,
99
- width=384
100
  ).images[0]
101
 
102
  return StreamingResponse(to_bytes(image), media_type="image/png")
103
 
104
 
105
  # ========================
106
- # IMG2IMG ENDPOINT (FAST MODE)
107
  # ========================
108
  @app.post("/img2img")
109
  async def img2img_api(
@@ -112,13 +97,13 @@ async def img2img_api(
112
  ):
113
 
114
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
115
- img = img.resize((384, 384)) # ⚡ تسريع مهم جدًا
116
 
117
- image = img2img(
118
  prompt=prompt,
119
  image=img,
120
  strength=0.6,
121
- num_inference_steps=10,
122
  guidance_scale=5
123
  ).images[0]
124
 
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
+ from diffusers import StableDiffusionPipeline
4
  import torch
5
  from PIL import Image
6
  import io
 
13
  LORA_PATH = "Interior_lora.safetensors"
14
 
15
  # ========================
16
+ #CPU OPTIMIZATION
17
  # ========================
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  dtype = torch.float16 if device == "cuda" else torch.float32
20
 
21
+ torch.set_num_threads(os.cpu_count())
22
 
23
+ print("Loading model...")
24
 
25
  # ========================
26
+ # SINGLE PIPELINE (IMPORTANT FIX)
27
  # ========================
28
+ pipe = StableDiffusionPipeline.from_single_file(
 
 
 
 
 
 
 
 
 
29
  MODEL_PATH,
30
  torch_dtype=dtype,
31
  safety_checker=None
 
33
 
34
  print("Loading LoRA...")
35
 
36
+ pipe.load_lora_weights(LORA_PATH)
37
+ pipe.fuse_lora(lora_scale=0.7)
 
 
 
38
 
39
  # ========================
40
+ # SPEED BOOSTS
41
  # ========================
42
+ pipe.enable_attention_slicing()
43
+ pipe.enable_vae_slicing()
 
 
 
44
 
45
+ print("Model ready 🔥")
46
 
47
 
48
  # ========================
 
71
 
72
 
73
  # ========================
74
+ # TXT2IMG (FAST MODE)
75
  # ========================
76
  @app.post("/txt2img")
77
  def generate(data: Prompt):
78
 
79
+ image = pipe(
80
  data.prompt,
81
+ num_inference_steps=6, # ⚡ أسرع بكتير
82
  guidance_scale=5,
83
+ height=256,
84
+ width=256
85
  ).images[0]
86
 
87
  return StreamingResponse(to_bytes(image), media_type="image/png")
88
 
89
 
90
  # ========================
91
+ # IMG2IMG (FAST MODE)
92
  # ========================
93
  @app.post("/img2img")
94
  async def img2img_api(
 
97
  ):
98
 
99
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
100
+ img = img.resize((256, 256)) # ⚡ أسرع بشكل واضح
101
 
102
+ image = pipe(
103
  prompt=prompt,
104
  image=img,
105
  strength=0.6,
106
+ num_inference_steps=6,
107
  guidance_scale=5
108
  ).images[0]
109