ebraam1 commited on
Commit
22e9227
·
verified ·
1 Parent(s): c575fec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -61
app.py CHANGED
@@ -1,110 +1,126 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
- from diffusers import StableDiffusionPipeline
4
- import torch
 
5
  from PIL import Image
 
 
6
  import io
7
- import os
8
- from fastapi.responses import StreamingResponse
9
 
10
  app = FastAPI()
11
 
12
- MODEL_PATH = "Interior.safetensors"
13
- LORA_PATH = "Interior_lora.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # ========================
16
- # ⚡ CPU OPTIMIZATION
17
- # ========================
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
- dtype = torch.float16 if device == "cuda" else torch.float32
20
 
21
- torch.set_num_threads(os.cpu_count())
22
 
23
- print("Loading model...")
24
 
25
- # ========================
26
- # SINGLE PIPELINE (IMPORTANT FIX)
27
- # ========================
28
- pipe = StableDiffusionPipeline.from_single_file(
29
  MODEL_PATH,
30
- torch_dtype=dtype,
31
  safety_checker=None
32
- ).to(device)
33
 
34
  print("Loading LoRA...")
35
 
36
- pipe.load_lora_weights(LORA_PATH)
37
- pipe.fuse_lora(lora_scale=0.7)
 
 
 
 
 
38
 
39
- # ========================
40
- # SPEED BOOSTS
41
- # ========================
42
- pipe.enable_attention_slicing()
43
- pipe.enable_vae_slicing()
44
 
45
- print("Model ready 🔥")
46
 
 
 
 
47
 
48
- # ========================
49
- # REQUEST MODEL
50
- # ========================
51
  class Prompt(BaseModel):
52
  prompt: str
53
 
 
 
 
54
 
55
- # ========================
56
- # IMAGE UTILS
57
- # ========================
58
  def to_bytes(img):
59
  buf = io.BytesIO()
60
  img.save(buf, format="PNG")
61
  buf.seek(0)
62
  return buf
63
 
 
 
 
 
 
 
 
 
 
64
 
65
- # ========================
66
- # HEALTH CHECK
67
- # ========================
68
- @app.get("/")
69
- def home():
70
- return {"status": "API is running 🚀"}
71
 
72
 
73
- # ========================
74
- # TXT2IMG (FAST MODE)
75
- # ========================
76
  @app.post("/txt2img")
77
  def generate(data: Prompt):
78
 
79
- image = pipe(
80
  data.prompt,
81
- num_inference_steps=6, # ⚡ أسرع بكتير
82
- guidance_scale=5,
83
- height=256,
84
- width=256
85
  ).images[0]
86
 
87
  return StreamingResponse(to_bytes(image), media_type="image/png")
88
 
 
 
 
89
 
90
- # ========================
91
- # IMG2IMG (FAST MODE)
92
- # ========================
93
  @app.post("/img2img")
94
- async def img2img_api(
95
- file: UploadFile = File(...),
96
- prompt: str = ""
97
- ):
98
 
99
- img = Image.open(io.BytesIO(await file.read())).convert("RGB")
100
- img = img.resize((256, 256)) # ⚡ أسرع بشكل واضح
101
 
102
- image = pipe(
 
103
  prompt=prompt,
104
  image=img,
105
- strength=0.6,
106
- num_inference_steps=6,
107
- guidance_scale=5
108
  ).images[0]
109
 
110
- return StreamingResponse(to_bytes(image), media_type="image/png")
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from pydantic import BaseModel
3
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
+ from huggingface_hub import hf_hub_download
5
+ from fastapi.responses import StreamingResponse
6
  from PIL import Image
7
+ import torch
8
+
9
  import io
10
+
11
+
12
 
13
  app = FastAPI()
14
 
15
+ # =========================
16
+ # تحميل الموديلات من HuggingFace
17
+ # =========================
18
+ MODEL_PATH = hf_hub_download(
19
+ repo_id="ebraam1/interior-sd-models",
20
+ filename="Interior.safetensors"
21
+ )
22
+
23
+ LORA_PATH = hf_hub_download(
24
+ repo_id="ebraam1/interior-sd-models",
25
+ filename="Interior_lora.safetensors"
26
+ )
27
+
28
+ print("Loading base model...")
29
+
30
+ # =========================
31
+ # Load Stable Diffusion
32
+ # =========================
33
+ txt2img = StableDiffusionPipeline.from_single_file(
34
+ MODEL_PATH,
35
+ torch_dtype=torch.float16,
36
+ safety_checker=None
37
+ ).to("cpu") # غيّرها لـ "cuda" لو GPU متاح
38
+
39
+ img2img = StableDiffusionImg2ImgPipeline.from_single_file(
40
+
41
 
 
 
 
 
 
42
 
 
43
 
 
44
 
 
 
 
 
45
  MODEL_PATH,
46
+ torch_dtype=torch.float16,
47
  safety_checker=None
48
+ ).to("cpu")
49
 
50
  print("Loading LoRA...")
51
 
52
+ txt2img.load_lora_weights(LORA_PATH)
53
+ img2img.load_lora_weights(LORA_PATH)
54
+
55
+
56
+
57
+
58
+
59
 
 
 
 
 
 
60
 
61
+ print("LoRA loaded 🔥")
62
 
63
+ # =========================
64
+ # API Schema
65
+ # =========================
66
 
 
 
 
67
  class Prompt(BaseModel):
68
  prompt: str
69
 
70
+ # =========================
71
+ # Helper: PIL → Bytes
72
+ # =========================
73
 
 
 
 
74
  def to_bytes(img):
75
  buf = io.BytesIO()
76
  img.save(buf, format="PNG")
77
  buf.seek(0)
78
  return buf
79
 
80
+ # =========================
81
+ # TEXT → IMAGE
82
+ # =========================
83
+
84
+
85
+
86
+
87
+
88
+
89
 
 
 
 
 
 
 
90
 
91
 
 
 
 
92
  @app.post("/txt2img")
93
  def generate(data: Prompt):
94
 
95
+ image = txt2img(
96
  data.prompt,
97
+ cross_attention_kwargs={"scale": 0.8}
98
+
99
+
100
+
101
  ).images[0]
102
 
103
  return StreamingResponse(to_bytes(image), media_type="image/png")
104
 
105
+ # =========================
106
+ # IMAGE → IMAGE
107
+ # =========================
108
 
 
 
 
109
  @app.post("/img2img")
110
+ async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
111
+
112
+
113
+
114
 
115
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512, 512))
 
116
 
117
+
118
+ image = img2img(
119
  prompt=prompt,
120
  image=img,
121
+ cross_attention_kwargs={"scale": 0.8}
122
+
123
+
124
  ).images[0]
125
 
126
+ return StreamingResponse(to_bytes(image), media_type="image/png")