| from flask import request |
| from diffusers import StableDiffusionPipeline |
| import torch |
| from fastapi import FastAPI, Response |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, using any HTTP
# method and any request header.
# NOTE(review): the Starlette CORS docs warn that "*" wildcards cannot be
# combined with credentials. This app reads no cookies or auth headers, so
# allow_credentials=True looks unnecessary — confirm with the frontend owners
# before dropping it.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
|
|
|
|
model_id = "runwayml/stable-diffusion-v1-5"

# Device the pipeline runs on. This stays on CPU, as in the original setup;
# set it to "cuda" to run on a GPU.
device = "cpu"

# Half precision is only used on CUDA. PyTorch's CPU backend lacks (or is very
# slow at) float16 kernels for several ops the pipeline uses, so float16
# weights on CPU fail with "... not implemented for 'Half'" errors or crawl.
# Fall back to float32 everywhere else.
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
|
|
def dummy(images, **kwargs):
    """Stand-in safety checker that flags no image.

    Meant to be installed as ``pipe.safety_checker`` so images come back
    exactly as generated.

    Args:
        images: The batch of images handed over by the pipeline. Only
            ``len(images)`` is used; the batch is returned untouched.
        **kwargs: Any other arguments the pipeline passes (ignored).

    Returns:
        A tuple ``(images, flags)``. ``images`` is the input object, unchanged.
        ``flags`` is a list holding one ``False`` per image.
    """
    # One flag per image: recent diffusers releases iterate these flags
    # image-by-image, so a single bare ``False`` raises TypeError there.
    return images, [False] * len(images)
|
|
# Swap the pipeline's safety checker for ``dummy``, which hands every image
# back unchanged.
# NOTE(review): this disables the model's content filtering — confirm the
# deployment policy allows it.
setattr(pipe, "safety_checker", dummy)
|
|
@app.get("/")
def generate_image(prompt: str):
    """Generate one image for ``prompt`` and return its raw pixel data.

    Served at ``GET /?prompt=...``. FastAPI reads ``prompt`` from the query
    string and answers with a 422 validation error when it is missing.

    Args:
        prompt: Text prompt passed straight to the diffusion pipeline.

    Returns:
        A JSON-serializable dict containing:
        - ``image_data``: the image's raw pixel buffer (``Image.tobytes()``),
          hex-encoded.
        - ``width``, ``height`` and ``mode``: the metadata needed to rebuild
          the image from that buffer, e.g. ``Image.frombytes(mode,
          (width, height), bytes.fromhex(image_data))``.
    """
    # Declared with FastAPI's own routing and a typed parameter. The
    # ``flask.request`` proxy only works inside a Flask app, and Starlette's
    # bare ``app.route`` would call this function with a ``Request`` argument.
    #
    # Assumes the pipeline's default output_type ("pil"), so images[0] is a
    # PIL image.
    # NOTE(review): FastAPI runs sync endpoints in a threadpool, so concurrent
    # requests share one ``pipe``. Confirm the pipeline is safe to call
    # concurrently, or serialize calls with a lock.
    image = pipe(prompt).images[0]

    return {
        "image_data": image.tobytes().hex(),
        "width": image.width,
        "height": image.height,
        "mode": image.mode,
    }
|
|