ebraam1 commited on
Commit
8512e9b
·
verified ·
1 Parent(s): 4a17d02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -19
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
  import torch
5
  from PIL import Image
6
  import io
@@ -11,27 +11,32 @@ app = FastAPI()
11
  MODEL_PATH = "Interior.safetensors"
12
  LORA_PATH = "Interior_lora.safetensors"
13
 
14
- print("Loading base model...")
15
 
16
- txt2img = StableDiffusionPipeline.from_single_file(
 
 
 
17
  MODEL_PATH,
18
- torch_dtype=torch.float16,
19
- safety_checker=None
20
- ).to("cpu") # هنرجعها GPU لو متاح لاحقًا
21
-
22
- img2img = StableDiffusionImg2ImgPipeline.from_single_file(
23
- MODEL_PATH,
24
- torch_dtype=torch.float16,
25
  safety_checker=None
26
  ).to("cpu")
27
 
28
  print("Loading LoRA...")
29
 
30
- txt2img.load_lora_weights(LORA_PATH)
31
- img2img.load_lora_weights(LORA_PATH)
32
 
33
- print("LoRA loaded 🔥")
 
 
34
 
 
 
 
 
 
 
35
  class Prompt(BaseModel):
36
  prompt: str
37
 
@@ -43,26 +48,38 @@ def to_bytes(img):
43
  return buf
44
 
45
 
 
 
 
46
  @app.post("/txt2img")
47
  def generate(data: Prompt):
48
 
49
- image = txt2img(
50
  data.prompt,
51
- cross_attention_kwargs={"scale": 0.8}
 
 
 
52
  ).images[0]
53
 
54
  return StreamingResponse(to_bytes(image), media_type="image/png")
55
 
56
 
 
 
 
57
  @app.post("/img2img")
58
  async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
59
 
60
- img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512,512))
 
61
 
62
- image = img2img(
63
  prompt=prompt,
64
  image=img,
65
- cross_attention_kwargs={"scale": 0.8}
 
 
66
  ).images[0]
67
 
68
- return StreamingResponse(to_bytes(image), media_type="image/png")
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
+ from diffusers import StableDiffusionPipeline
4
  import torch
5
  from PIL import Image
6
  import io
 
11
  MODEL_PATH = "Interior.safetensors"
12
  LORA_PATH = "Interior_lora.safetensors"
13
 
14
+ print("Loading model...")
15
 
16
+ # ========================
17
+ # ⚡ SAFE CPU CONFIG
18
+ # ========================
19
+ pipe = StableDiffusionPipeline.from_single_file(
20
  MODEL_PATH,
21
+ torch_dtype=torch.float32, # ✔ مهم جدًا
 
 
 
 
 
 
22
  safety_checker=None
23
  ).to("cpu")
24
 
25
  print("Loading LoRA...")
26
 
27
+ pipe.load_lora_weights(LORA_PATH)
28
+ pipe.fuse_lora(lora_scale=0.8)
29
 
30
+ # speed boost
31
+ pipe.enable_attention_slicing()
32
+ pipe.enable_vae_slicing()
33
 
34
+ print("Model ready 🔥")
35
+
36
+
37
+ # ========================
38
+ # REQUEST MODEL
39
+ # ========================
40
  class Prompt(BaseModel):
41
  prompt: str
42
 
 
48
  return buf
49
 
50
 
51
+ # ========================
52
+ # TXT2IMG
53
+ # ========================
54
  @app.post("/txt2img")
55
  def generate(data: Prompt):
56
 
57
+ image = pipe(
58
  data.prompt,
59
+ num_inference_steps=6, # ⚡ سريع
60
+ guidance_scale=5,
61
+ height=256,
62
+ width=256
63
  ).images[0]
64
 
65
  return StreamingResponse(to_bytes(image), media_type="image/png")
66
 
67
 
68
+ # ========================
69
+ # IMG2IMG
70
+ # ========================
71
  @app.post("/img2img")
72
  async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
73
 
74
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
75
+ img = img.resize((256, 256)) # ⚡ أسرع
76
 
77
+ image = pipe(
78
  prompt=prompt,
79
  image=img,
80
+ strength=0.6,
81
+ num_inference_steps=6,
82
+ guidance_scale=5
83
  ).images[0]
84
 
85
+ return StreamingResponse(to_bytes(image), media_type="image/png")