ebraam1 commited on
Commit
764f1e8
·
verified ·
1 Parent(s): 9668c21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -68
app.py CHANGED
@@ -1,91 +1,70 @@
1
- from fastapi import FastAPI, UploadFile, File, Query
2
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
3
- from huggingface_hub import hf_hub_download
4
- from fastapi.responses import StreamingResponse
5
- from PIL import Image
6
  import torch
 
7
  import io
 
8
 
9
  app = FastAPI()
10
 
11
- MODEL_REPO = "ebraam1/interior-sd-models"
12
- BASE_MODEL = "Interior.safetensors"
13
- LORA_FILE = "Interior_lora.safetensors"
14
 
15
- # تحميل الموديل الأساسي
16
- print("Downloading base model...")
17
- base_path = hf_hub_download(repo_id=MODEL_REPO, filename=BASE_MODEL)
18
 
19
- print("Loading base model on CPU...")
20
- pipe = StableDiffusionPipeline.from_single_file(
21
- base_path,
22
- torch_dtype=torch.float32,
23
- safety_checker=None,
24
- requires_safety_checker=False
25
- ).to("cpu")
26
 
27
- # أفضل Scheduler للجودة العالية
28
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # تحميل LoRA ودمجه
31
- print("Downloading LoRA...")
32
- lora_path = hf_hub_download(repo_id=MODEL_REPO, filename=LORA_FILE)
33
 
34
  print("Loading LoRA...")
35
- pipe.load_lora_weights(lora_path)
36
- pipe.fuse_lora(lora_scale=1.0)
37
 
38
- # تحسينات الذاكرة
39
- pipe.enable_attention_slicing()
40
- pipe.enable_vae_slicing()
 
 
41
 
42
- print("Model ready 🔥 (Max quality mode - Landscape)")
43
 
44
- def to_bytes(img: Image.Image) -> io.BytesIO:
 
 
 
 
 
45
  buf = io.BytesIO()
46
  img.save(buf, format="PNG")
47
  buf.seek(0)
48
  return buf
49
 
50
- # ================== Text-to-Image (GET) ==================
51
- @app.get("/txt2img")
52
- def txt2img(
53
- prompt: str = Query(..., description="الوصف"),
54
- negative_prompt: str = Query("", description="العناصر اللي عايز تتجنبها"),
55
- steps: int = Query(20, ge=1, le=30),
56
- guidance: float = Query(9.0, ge=1.0, le=20.0),
57
- height: int = Query(512, ge=256, le=768), # أصبح أقل (عرضي)
58
- width: int = Query(768, ge=256, le=768) # أصبح أكبر
59
- ):
60
- image = pipe(
61
- prompt=prompt,
62
- negative_prompt=negative_prompt or None,
63
- num_inference_steps=steps,
64
- guidance_scale=guidance,
65
- height=height,
66
- width=width
67
- ).images[0]
68
  return StreamingResponse(to_bytes(image), media_type="image/png")
69
 
70
- # ================== Image-to-Image (POST) ==================
71
  @app.post("/img2img")
72
- async def img2img(
73
- file: UploadFile = File(...),
74
- prompt: str = Query(""),
75
- negative_prompt: str = Query(""),
76
- steps: int = Query(20, ge=1, le=30),
77
- guidance: float = Query(9.0, ge=1.0, le=20.0),
78
- strength: float = Query(0.6, ge=0.0, le=1.0),
79
- height: int = Query(512, ge=256, le=768),
80
- width: int = Query(768, ge=256, le=768)
81
- ):
82
- img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((width, height))
83
- image = pipe(
84
- prompt=prompt,
85
- negative_prompt=negative_prompt or None,
86
- image=img,
87
- strength=strength,
88
- num_inference_steps=steps,
89
- guidance_scale=guidance
90
- ).images[0]
91
  return StreamingResponse(to_bytes(image), media_type="image/png")
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from pydantic import BaseModel
3
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
 
 
4
  import torch
5
+ from PIL import Image
6
  import io
7
+ from fastapi.responses import StreamingResponse
8
 
9
  app = FastAPI()
10
 
11
+ MODEL_PATH = "Interior.safetensors"
12
+ LORA_PATH = "Interior_lora.safetensors"
 
13
 
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ dtype = torch.float16 if device == "cuda" else torch.float32
 
16
 
 
 
 
 
 
 
 
17
 
18
+ print("Loading base model...")
19
+
20
+ txt2img = StableDiffusionPipeline.from_single_file(
21
+ MODEL_PATH,
22
+ torch_dtype=dtype,
23
+ safety_checker=None
24
+ ).to(device)
25
+
26
+ img2img = StableDiffusionImg2ImgPipeline.from_single_file(
27
+ MODEL_PATH,
28
+ torch_dtype=dtype,
29
+ safety_checker=None
30
+ ).to(device)
31
 
 
 
 
32
 
33
  print("Loading LoRA...")
 
 
34
 
35
+ txt2img.load_lora_weights(LORA_PATH)
36
+ img2img.load_lora_weights(LORA_PATH)
37
+
38
+ txt2img.fuse_lora(lora_scale=0.8)
39
+ img2img.fuse_lora(lora_scale=0.8)
40
 
41
+ print("LoRA loaded 🔥")
42
 
43
+
44
+ class Prompt(BaseModel):
45
+ prompt: str
46
+
47
+
48
+ def to_bytes(img):
49
  buf = io.BytesIO()
50
  img.save(buf, format="PNG")
51
  buf.seek(0)
52
  return buf
53
 
54
+
55
+ @app.get("/")
56
+ def home():
57
+ return {"status": "API is running 🚀"}
58
+
59
+
60
+ @app.post("/txt2img")
61
+ def generate(data: Prompt):
62
+ image = txt2img(data.prompt).images[0]
 
 
 
 
 
 
 
 
 
63
  return StreamingResponse(to_bytes(image), media_type="image/png")
64
 
65
+
66
  @app.post("/img2img")
67
+ async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
68
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512, 512))
69
+ image = img2img(prompt=prompt, image=img).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return StreamingResponse(to_bytes(image), media_type="image/png")