Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from fastapi import FastAPI, Request, UploadFile, File, Query
|
| 2 |
from PIL import Image, ImageDraw, ImageFont
|
| 3 |
import torch
|
| 4 |
-
from diffusers import StableDiffusionImg2ImgPipeline
|
| 5 |
from starlette.responses import StreamingResponse
|
| 6 |
from fastapi.responses import FileResponse
|
| 7 |
import io
|
|
@@ -12,11 +12,20 @@ device = "cuda"
|
|
| 12 |
# , torch_dtype=torch.float16
|
| 13 |
device1 = "cpu"
|
| 14 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device1)
|
|
|
|
| 15 |
|
| 16 |
@app.get("/")
|
| 17 |
async def root():
|
| 18 |
return {"message": "Welcome to the Creating Ad Template With Stable Diffusion API!"}
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
@app.post("/generate_ad")
|
| 21 |
async def generate_ad(prompt: str,hex_code:str, button_color:str, punchline_color:str, image_file: UploadFile, logo: UploadFile):
|
| 22 |
|
|
|
|
| 1 |
from fastapi import FastAPI, Request, UploadFile, File, Query
|
| 2 |
from PIL import Image, ImageDraw, ImageFont
|
| 3 |
import torch
|
| 4 |
+
from diffusers import StableDiffusionImg2ImgPipeline, DiffusionPipeline
|
| 5 |
from starlette.responses import StreamingResponse
|
| 6 |
from fastapi.responses import FileResponse
|
| 7 |
import io
|
|
|
|
| 12 |
# , torch_dtype=torch.float16
|
| 13 |
device1 = "cpu"
|
| 14 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device1)
|
| 15 |
+
pipe1 = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
| 16 |
|
| 17 |
@app.get("/")
|
| 18 |
async def root():
|
| 19 |
return {"message": "Welcome to the Creating Ad Template With Stable Diffusion API!"}
|
| 20 |
+
|
| 21 |
+
@app.get("/generate", response_class=StreamingResponse)
|
| 22 |
+
async def generate_image(prompt: str = Query(..., description="Text prompt for image generation")):
|
| 23 |
+
image = pipe1(prompt).images[0]
|
| 24 |
+
image_data = io.BytesIO()
|
| 25 |
+
image.save(image_data, format="PNG")
|
| 26 |
+
image_data.seek(0)
|
| 27 |
+
return StreamingResponse(image_data, media_type="image/png")
|
| 28 |
+
|
| 29 |
@app.post("/generate_ad")
|
| 30 |
async def generate_ad(prompt: str,hex_code:str, button_color:str, punchline_color:str, image_file: UploadFile, logo: UploadFile):
|
| 31 |
|