ebraam1 commited on
Commit
e116903
·
verified ·
1 Parent(s): 52a3b75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -11,27 +11,36 @@ app = FastAPI()
11
  MODEL_PATH = "Interior.safetensors"
12
  LORA_PATH = "Interior_lora.safetensors"
13
 
 
 
 
 
14
  print("Loading base model...")
15
 
16
  txt2img = StableDiffusionPipeline.from_single_file(
17
  MODEL_PATH,
18
- torch_dtype=torch.float16,
19
  safety_checker=None
20
- ).to("cpu") # هنرجعها GPU لو متاح لاحقًا
21
 
22
  img2img = StableDiffusionImg2ImgPipeline.from_single_file(
23
  MODEL_PATH,
24
- torch_dtype=torch.float16,
25
  safety_checker=None
26
- ).to("cpu")
 
27
 
28
  print("Loading LoRA...")
29
 
30
  txt2img.load_lora_weights(LORA_PATH)
31
  img2img.load_lora_weights(LORA_PATH)
32
 
 
 
 
33
  print("LoRA loaded 🔥")
34
 
 
35
  class Prompt(BaseModel):
36
  prompt: str
37
 
@@ -43,26 +52,19 @@ def to_bytes(img):
43
  return buf
44
 
45
 
46
- @app.post("/txt2img")
47
- def generate(data: Prompt):
 
48
 
49
- image = txt2img(
50
- data.prompt,
51
- cross_attention_kwargs={"scale": 0.8}
52
- ).images[0]
53
 
 
 
 
54
  return StreamingResponse(to_bytes(image), media_type="image/png")
55
 
56
 
57
  @app.post("/img2img")
58
  async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
59
-
60
- img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512,512))
61
-
62
- image = img2img(
63
- prompt=prompt,
64
- image=img,
65
- cross_attention_kwargs={"scale": 0.8}
66
- ).images[0]
67
-
68
  return StreamingResponse(to_bytes(image), media_type="image/png")
 
11
  MODEL_PATH = "Interior.safetensors"
12
  LORA_PATH = "Interior_lora.safetensors"
13
 
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ dtype = torch.float16 if device == "cuda" else torch.float32
16
+
17
+
18
  print("Loading base model...")
19
 
20
  txt2img = StableDiffusionPipeline.from_single_file(
21
  MODEL_PATH,
22
+ torch_dtype=dtype,
23
  safety_checker=None
24
+ ).to(device)
25
 
26
  img2img = StableDiffusionImg2ImgPipeline.from_single_file(
27
  MODEL_PATH,
28
+ torch_dtype=dtype,
29
  safety_checker=None
30
+ ).to(device)
31
+
32
 
33
  print("Loading LoRA...")
34
 
35
  txt2img.load_lora_weights(LORA_PATH)
36
  img2img.load_lora_weights(LORA_PATH)
37
 
38
+ txt2img.fuse_lora(lora_scale=0.8)
39
+ img2img.fuse_lora(lora_scale=0.8)
40
+
41
  print("LoRA loaded 🔥")
42
 
43
+
44
  class Prompt(BaseModel):
45
  prompt: str
46
 
 
52
  return buf
53
 
54
 
55
+ @app.get("/")
56
+ def home():
57
+ return {"status": "API is running 🚀"}
58
 
 
 
 
 
59
 
60
+ @app.post("/txt2img")
61
+ def generate(data: Prompt):
62
+ image = txt2img(data.prompt).images[0]
63
  return StreamingResponse(to_bytes(image), media_type="image/png")
64
 
65
 
66
  @app.post("/img2img")
67
  async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
68
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB").resize((512, 512))
69
+ image = img2img(prompt=prompt, image=img).images[0]
 
 
 
 
 
 
 
70
  return StreamingResponse(to_bytes(image), media_type="image/png")