Spaces:
Paused
Paused
import base64
import threading
from io import BytesIO
from os.path import dirname, join
from typing import Annotated

import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, Path, Query, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from torch import autocast
| # class Prompt(BaseModel): | |
| # prompt: str | |
| # steps: Annotated[int, Path(title="No of steps", ge=4, le=10)] = 8 | |
| # guide: Annotated[float, Path(title="Guidance scale", ge=0.5, le=2)] = 0.8 | |
# The FastAPI application instance for this module.
app = FastAPI()

# Wide-open CORS policy: any origin, any HTTP method and any request header
# are accepted, and credentialed (cookie / auth) requests are permitted.
# NOTE(review): the CORS spec does not allow a literal "*" origin together
# with credentials; recent Starlette versions reflect the caller's Origin in
# that case, which effectively grants credentialed access to every site.
# Nothing visible here relies on cookies or auth — confirm whether
# allow_credentials=True is actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Torch device the pipeline is moved to; generate() also passes it to autocast.
device = "cuda"

# Load the diffusion pipeline from the "cgt2im" directory next to this source
# file. os.path.join is used instead of an f-string: dirname() returns "" for
# a bare filename, and f"{''}/cgt2im" would silently become the filesystem-root
# path "/cgt2im", whereas join("", "cgt2im") stays relative ("cgt2im").
pipe = DiffusionPipeline.from_pretrained(join(dirname(__file__), "cgt2im"))

# Move the pipeline to the CUDA device and cast it to float16.
pipe = pipe.to(device, dtype=torch.float16)
| # @app.get("/") | |
| # def generate(prompt: str): | |
| # with autocast(device): | |
| # image = pipe( | |
| # prompt=prompt, | |
| # num_inference_steps=8, | |
| # guidance_scale=8.0, | |
| # lcm_origin_steps=50, | |
| # output_type="pil", | |
| # ).images[0] | |
| # # image.save("testimage.png") | |
| # buffer = BytesIO() | |
| # image.save(buffer, format="PNG") | |
| # imgstr = base64.b64encode(buffer.getvalue()) | |
| # return Response(content=imgstr, media_type="image/png") | |
def generate(prompt: str,
             steps: Annotated[int, Query(ge=4, le=10)] = 8,
             guide: Annotated[float, Query(ge=0.5, le=2)] = 0.8,
             ):
    """Render one image for *prompt* and return it as base64-encoded PNG.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        steps: Number of inference steps. When served by FastAPI, the
            ``Query(ge=4, le=10)`` annotation restricts it to 4..10.
        guide: Guidance scale passed to the pipeline. ``Query(ge=0.5, le=2)``
            restricts it to 0.5..2 when served by FastAPI.

    Returns:
        A ``Response`` whose body is the base64 encoding of the PNG bytes.
    """
    # ``pipe`` is one module-level object shared by every request. FastAPI
    # runs plain ``def`` handlers in a worker thread pool, and diffusers
    # documents that pipeline schedulers are stateful and not thread-safe.
    # Serialize inference so concurrent requests cannot step the same
    # scheduler at once (this also keeps them from competing for GPU memory).
    with _PIPE_LOCK:
        with autocast(device):
            image = pipe(
                prompt=prompt,
                num_inference_steps=steps,
                guidance_scale=guide,
                # Pipeline-specific keyword, passed through unchanged;
                # presumably the LCM original-schedule length — confirm.
                lcm_origin_steps=50,
                output_type="pil",
            ).images[0]

    # PNG encoding does not touch the shared pipeline, so it runs outside
    # the lock.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    imgstr = base64.b64encode(buffer.getvalue())
    # NOTE(review): the body is base64 *text*, yet it is labelled image/png.
    # The client presumably decodes it itself (e.g. into a data: URL); confirm
    # against the front end before changing the body or the media type.
    return Response(content=imgstr, media_type="image/png")


# Serializes access to the shared ``pipe`` across request threads. Defining it
# after generate() is safe: the name is only looked up when generate() runs.
_PIPE_LOCK = threading.Lock()
async def read_home():
    """Return the contents of ``app/static/index.html`` as an HTML response."""
    # Decode explicitly as UTF-8. Without ``encoding``, open() uses the
    # locale's preferred encoding, while HTMLResponse re-encodes the text as
    # UTF-8, so non-ASCII characters could be garbled on some hosts.
    # NOTE(review): this path is resolved against the current working
    # directory, not this file's directory — confirm the server is always
    # started from the directory that contains ``app/``.
    with open("app/static/index.html", "r", encoding="utf-8") as file:
        content = file.read()
    return HTMLResponse(content=content)
| # @app.post("/t2i") | |
| # def generate(prompt: Prompt): | |
| # with autocast(device): | |
| # image = pipe( | |
| # prompt=prompt.prompt, | |
| # num_inference_steps=prompt.steps, | |
| # guidance_scale=prompt.guide, | |
| # lcm_origin_steps=50, | |
| # output_type="pil", | |
| # ).images[0] | |
| # # image.save("testimage.png") | |
| # buffer = BytesIO() | |
| # image.save(buffer, format="PNG") | |
| # imgstr = base64.b64encode(buffer.getvalue()) | |
| # return Response(content=imgstr, media_type="image/png") |