|
|
|
|
|
from fastapi import FastAPI, Response |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from pydantic import BaseModel |
|
|
import io |
|
|
import base64 |
|
|
|
|
|
|
|
|
class ImageRequest(BaseModel):
    """Request body schema for the /generate-image endpoint."""

    # Free-text description of the image to generate.
    prompt: str
|
|
|
|
|
|
|
|
# FastAPI application instance; routes below are registered on it via decorators.
app = FastAPI()

# Allow cross-origin requests so a separately hosted frontend can call this API.
# NOTE(review): per the CORS spec, allow_credentials=True combined with
# allow_origins=["*"] is contradictory — browsers reject credentialed requests
# against a wildcard origin (Starlette compensates by echoing the request's
# Origin header, which effectively trusts every site). Confirm whether
# credentials are actually needed; if not, drop allow_credentials, otherwise
# list the real frontend origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the Stable Diffusion pipeline once at module import so every request
# reuses the same in-memory weights. The first run downloads "segmind/tiny-sd"
# into the Hugging Face cache, which can take a while.
print("Loading model...")
pipe = StableDiffusionPipeline.from_pretrained(
    "segmind/tiny-sd",
    # float32 keeps the pipeline runnable on CPU (float16 generally needs a GPU).
    torch_dtype=torch.float32
)

# Trade a little inference speed for a smaller peak-memory footprint during
# the attention computation — helpful on low-RAM machines.
pipe.enable_attention_slicing()
print("Model loaded successfully!")
|
|
|
|
|
|
|
|
|
|
|
@app.post("/generate-image")
async def generate_image(request: ImageRequest):
    """Generate an image from a text prompt and return it as base64-encoded PNG.

    Expects a JSON body matching ImageRequest ({"prompt": str}).

    Returns:
        {"image_data": "<base64 PNG>"} on success, or a plain-text 500
        Response on failure.
    """
    try:
        prompt = request.prompt
        print(f"Generating image for prompt: {prompt}")

        # 25 steps / guidance_scale 7.5 are common quality-vs-speed defaults;
        # the pipeline returns a list of PIL images, we use the first.
        image = pipe(prompt, num_inference_steps=25, guidance_scale=7.5).images[0]

        print("Image generated.")

        # Serialize the PIL image to PNG entirely in memory, then base64-encode
        # it so it can travel inside a JSON payload.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return {"image_data": img_str}

    except Exception as e:
        # Log full details server-side, but do NOT echo the raw exception back
        # to the client: exception text can leak internal paths/configuration
        # to untrusted callers. (The original response interpolated `e`.)
        print(f"An error occurred during image generation: {e}")
        return Response(content="An error occurred during image generation.", status_code=500)
|
|
|
|
|
@app.get("/")
def read_root():
    """Health-check endpoint confirming the API process is up."""
    status_payload = {"Status": "API is running"}
    return status_payload