farukbera commited on
Commit
be5df3f
·
1 Parent(s): cb289b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import FastAPI, Request, UploadFile, File, Query
2
  from PIL import Image, ImageDraw, ImageFont
3
  import torch
4
- from diffusers import StableDiffusionImg2ImgPipeline
5
  from starlette.responses import StreamingResponse
6
  from fastapi.responses import FileResponse
7
  import io
@@ -12,11 +12,20 @@ device = "cuda"
12
  # , torch_dtype=torch.float16
13
  device1 = "cpu"
14
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device1)
 
15
 
16
  @app.get("/")
17
  async def root():
18
  return {"message": "Welcome to the Creating Ad Template With Stable Diffusion API!"}
19
-
 
 
 
 
 
 
 
 
20
  @app.post("/generate_ad")
21
  async def generate_ad(prompt: str,hex_code:str, button_color:str, punchline_color:str, image_file: UploadFile, logo: UploadFile):
22
 
 
1
  from fastapi import FastAPI, Request, UploadFile, File, Query
2
  from PIL import Image, ImageDraw, ImageFont
3
  import torch
4
+ from diffusers import StableDiffusionImg2ImgPipeline, DiffusionPipeline
5
  from starlette.responses import StreamingResponse
6
  from fastapi.responses import FileResponse
7
  import io
 
12
  # , torch_dtype=torch.float16
13
  device1 = "cpu"
14
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device1)
15
+ pipe1 = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
16
 
17
  @app.get("/")
18
  async def root():
19
  return {"message": "Welcome to the Creating Ad Template With Stable Diffusion API!"}
20
+
21
+ @app.get("/generate", response_class=StreamingResponse)
22
+ async def generate_image(prompt: str = Query(..., description="Text prompt for image generation")):
23
+ image = pipe1(prompt).images[0]
24
+ image_data = io.BytesIO()
25
+ image.save(image_data, format="PNG")
26
+ image_data.seek(0)
27
+ return StreamingResponse(image_data, media_type="image/png")
28
+
29
  @app.post("/generate_ad")
30
  async def generate_ad(prompt: str,hex_code:str, button_color:str, punchline_color:str, image_file: UploadFile, logo: UploadFile):
31