ebraam1 commited on
Commit
4a17d02
·
verified ·
1 Parent(s): dd92e9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -62
app.py CHANGED
@@ -1,10 +1,9 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
- from diffusers import StableDiffusionPipeline
4
  import torch
5
  from PIL import Image
6
  import io
7
- import os
8
  from fastapi.responses import StreamingResponse
9
 
10
  app = FastAPI()
@@ -12,49 +11,31 @@ app = FastAPI()
12
  MODEL_PATH = "Interior.safetensors"
13
  LORA_PATH = "Interior_lora.safetensors"
14
 
15
- # ========================
16
- # ⚡ CPU OPTIMIZATION
17
- # ========================
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
- dtype = torch.float16 if device == "cuda" else torch.float32
20
 
21
- torch.set_num_threads(os.cpu_count())
22
-
23
- print("Loading model...")
 
 
24
 
25
- # ========================
26
- # SINGLE PIPELINE (IMPORTANT FIX)
27
- # ========================
28
- pipe = StableDiffusionPipeline.from_single_file(
29
  MODEL_PATH,
30
- torch_dtype=dtype,
31
  safety_checker=None
32
- ).to(device)
33
 
34
  print("Loading LoRA...")
35
 
36
- pipe.load_lora_weights(LORA_PATH)
37
- pipe.fuse_lora(lora_scale=0.7)
38
-
39
- # ========================
40
- # SPEED BOOSTS
41
- # ========================
42
- pipe.enable_attention_slicing()
43
- pipe.enable_vae_slicing()
44
 
45
- print("Model ready 🔥")
46
 
47
-
48
- # ========================
49
- # REQUEST MODEL
50
- # ========================
51
  class Prompt(BaseModel):
52
  prompt: str
53
 
54
 
55
- # ========================
56
- # IMAGE UTILS
57
- # ========================
58
  def to_bytes(img):
59
  buf = io.BytesIO()
60
  img.save(buf, format="PNG")
@@ -62,49 +43,26 @@ def to_bytes(img):
62
  return buf
63
 
64
 
65
- # ========================
66
- # HEALTH CHECK
67
- # ========================
68
- @app.get("/")
69
- def home():
70
- return {"status": "API is running 🚀"}
71
-
72
-
73
- # ========================
74
- # TXT2IMG (FAST MODE)
75
- # ========================
76
  @app.post("/txt2img")
77
  def generate(data: Prompt):
78
 
79
- image = pipe(
80
  data.prompt,
81
- num_inference_steps=6, # ⚡ أسرع بكتير
82
- guidance_scale=5,
83
- height=256,
84
- width=256
85
  ).images[0]
86
 
87
  return StreamingResponse(to_bytes(image), media_type="image/png")
88
 
89
 
90
- # ========================
91
- # IMG2IMG (FAST MODE)
92
- # ========================
93
  @app.post("/img2img")
94
- async def img2img_api(
95
- file: UploadFile = File(...),
96
- prompt: str = ""
97
- ):
98
 
99
- img = Image.open(io.BytesIO(await file.read())).convert("RGB")
100
- img = img.resize((256, 256)) # ⚡ أسرع بشكل واضح
101
 
102
- image = pipe(
103
  prompt=prompt,
104
  image=img,
105
- strength=0.6,
106
- num_inference_steps=6,
107
- guidance_scale=5
108
  ).images[0]
109
 
110
- return StreamingResponse(to_bytes(image), media_type="image/png")
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
  import torch
5
  from PIL import Image
6
  import io
 
7
  from fastapi.responses import StreamingResponse
8
 
9
  app = FastAPI()
 
11
  MODEL_PATH = "Interior.safetensors"
12
  LORA_PATH = "Interior_lora.safetensors"
13
 
14
+ print("Loading base model...")
 
 
 
 
15
 
16
+ txt2img = StableDiffusionPipeline.from_single_file(
17
+ MODEL_PATH,
18
+ torch_dtype=torch.float16,
19
+ safety_checker=None
20
+ ).to("cpu") # هنرجعها GPU لو متاح لاحقًا
21
 
22
+ img2img = StableDiffusionImg2ImgPipeline.from_single_file(
 
 
 
23
  MODEL_PATH,
24
+ torch_dtype=torch.float16,
25
  safety_checker=None
26
+ ).to("cpu")
27
 
28
  print("Loading LoRA...")
29
 
30
+ txt2img.load_lora_weights(LORA_PATH)
31
+ img2img.load_lora_weights(LORA_PATH)
 
 
 
 
 
 
32
 
33
+ print("LoRA loaded 🔥")
34
 
 
 
 
 
35
  class Prompt(BaseModel):
36
  prompt: str
37
 
38
 
 
 
 
39
  def to_bytes(img):
40
  buf = io.BytesIO()
41
  img.save(buf, format="PNG")
 
43
  return buf
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
46
  @app.post("/txt2img")
47
  def generate(data: Prompt):
48
 
49
+ image = txt2img(
50
  data.prompt,
51
+ cross_attention_kwargs={"scale": 0.8}
 
 
 
52
  ).images[0]
53
 
54
  return StreamingResponse(to_bytes(image), media_type="image/png")
55
 
56
 
 
 
 
57
  @app.post("/img2img")
58
+ async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
 
 
 
59
 
60
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512,512))
 
61
 
62
+ image = img2img(
63
  prompt=prompt,
64
  image=img,
65
+ cross_attention_kwargs={"scale": 0.8}
 
 
66
  ).images[0]
67
 
68
+ return StreamingResponse(to_bytes(image), media_type="image/png")