ebraam1 commited on
Commit
e5559c2
ยท
verified ยท
1 Parent(s): 4e704d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -39
app.py CHANGED
@@ -10,8 +10,13 @@ import io
10
  app = FastAPI()
11
 
12
  # =========================
13
- # LoRA
14
  # =========================
 
 
 
 
 
15
  LORA_PATH = hf_hub_download(
16
  repo_id="ebraam1/interior-sd-models",
17
  filename="Interior_lora.safetensors"
@@ -20,82 +25,67 @@ LORA_PATH = hf_hub_download(
20
  print("Loading base model...")
21
 
22
  # =========================
23
- # BASE PIPELINE (CPU SAFE)
24
  # =========================
25
- pipe_txt = StableDiffusionPipeline.from_pretrained(
26
- "runwayml/stable-diffusion-v1-5",
27
- torch_dtype=torch.float32,
28
  safety_checker=None
29
- ).to("cpu")
30
 
31
- pipe_img = StableDiffusionImg2ImgPipeline.from_pretrained(
32
- "runwayml/stable-diffusion-v1-5",
33
- torch_dtype=torch.float32,
34
  safety_checker=None
35
  ).to("cpu")
36
 
37
  print("Loading LoRA...")
38
 
39
- pipe_txt.load_lora_weights(LORA_PATH)
40
- pipe_img.load_lora_weights(LORA_PATH)
41
-
42
- pipe_txt.fuse_lora(lora_scale=0.8)
43
- pipe_img.fuse_lora(lora_scale=0.8)
44
-
45
- # โšก speed boosts
46
- pipe_txt.enable_attention_slicing()
47
- pipe_txt.enable_vae_slicing()
48
-
49
- pipe_img.enable_attention_slicing()
50
- pipe_img.enable_vae_slicing()
51
-
52
- print("Model ready ๐Ÿ”ฅ")
53
 
 
54
 
 
 
55
  # =========================
56
  class Prompt(BaseModel):
57
  prompt: str
58
 
59
-
 
 
60
  def to_bytes(img):
61
  buf = io.BytesIO()
62
  img.save(buf, format="PNG")
63
  buf.seek(0)
64
  return buf
65
 
66
-
67
  # =========================
68
- # TXT2IMG
69
  # =========================
70
  @app.post("/txt2img")
71
  def generate(data: Prompt):
72
 
73
- image = pipe_txt(
74
  data.prompt,
75
- num_inference_steps=5,
76
- guidance_scale=5,
77
- height=256,
78
- width=256
79
  ).images[0]
80
 
81
  return StreamingResponse(to_bytes(image), media_type="image/png")
82
 
83
-
84
  # =========================
85
- # IMG2IMG
86
  # =========================
87
  @app.post("/img2img")
88
  async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
89
 
90
- img = Image.open(io.BytesIO(await file.read())).convert("RGB")
91
- img = img.resize((256, 256))
92
 
93
- image = pipe_img(
94
  prompt=prompt,
95
  image=img,
96
- strength=0.6,
97
- num_inference_steps=5,
98
- guidance_scale=5
99
  ).images[0]
100
 
101
  return StreamingResponse(to_bytes(image), media_type="image/png")
 
10
  app = FastAPI()
11
 
12
  # =========================
13
+ # ุชุญู…ูŠู„ ุงู„ู…ูˆุฏูŠู„ุงุช ู…ู† HuggingFace
14
  # =========================
15
+ MODEL_PATH = hf_hub_download(
16
+ repo_id="ebraam1/interior-sd-models",
17
+ filename="Interior.safetensors"
18
+ )
19
+
20
  LORA_PATH = hf_hub_download(
21
  repo_id="ebraam1/interior-sd-models",
22
  filename="Interior_lora.safetensors"
 
25
  print("Loading base model...")
26
 
27
  # =========================
28
+ # Load Stable Diffusion
29
  # =========================
30
+ txt2img = StableDiffusionPipeline.from_single_file(
31
+ MODEL_PATH,
32
+ torch_dtype=torch.float16,
33
  safety_checker=None
34
+ ).to("cpu") # ุบูŠู‘ุฑู‡ุง ู„ู€ "cuda" ู„ูˆ GPU ู…ุชุงุญ
35
 
36
+ img2img = StableDiffusionImg2ImgPipeline.from_single_file(
37
+ MODEL_PATH,
38
+ torch_dtype=torch.float16,
39
  safety_checker=None
40
  ).to("cpu")
41
 
42
  print("Loading LoRA...")
43
 
44
+ txt2img.load_lora_weights(LORA_PATH)
45
+ img2img.load_lora_weights(LORA_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ print("LoRA loaded ๐Ÿ”ฅ")
48
 
49
+ # =========================
50
+ # API Schema
51
  # =========================
52
  class Prompt(BaseModel):
53
  prompt: str
54
 
55
+ # =========================
56
+ # Helper: PIL โ†’ Bytes
57
+ # =========================
58
  def to_bytes(img):
59
  buf = io.BytesIO()
60
  img.save(buf, format="PNG")
61
  buf.seek(0)
62
  return buf
63
 
 
64
  # =========================
65
+ # TEXT โ†’ IMAGE
66
  # =========================
67
  @app.post("/txt2img")
68
  def generate(data: Prompt):
69
 
70
+ image = txt2img(
71
  data.prompt,
72
+ cross_attention_kwargs={"scale": 0.8}
 
 
 
73
  ).images[0]
74
 
75
  return StreamingResponse(to_bytes(image), media_type="image/png")
76
 
 
77
  # =========================
78
+ # IMAGE โ†’ IMAGE
79
  # =========================
80
  @app.post("/img2img")
81
  async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
82
 
83
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512, 512))
 
84
 
85
+ image = img2img(
86
  prompt=prompt,
87
  image=img,
88
+ cross_attention_kwargs={"scale": 0.8}
 
 
89
  ).images[0]
90
 
91
  return StreamingResponse(to_bytes(image), media_type="image/png")