| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | import requests |
| | import random |
| | import os |
| | from fastapi.responses import Response |
| |
|
# ASGI application object; route handlers below are registered on it.
app = FastAPI()

# Inference endpoint URL; the model id is read from the HF_MODEL environment
# variable at import time.
API_URL = f'https://api-inference.huggingface.co/models/{os.getenv("HF_MODEL")}'

# Bearer-token header sent with every upstream request; the token is read from
# the HF_TOKEN environment variable at import time.
headers = {'Authorization': f'Bearer {os.getenv("HF_TOKEN")}'}

# Timeout, in seconds, passed to requests.post for the upstream call.
timeout = 100
| |
|
class ImageRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Only ``prompt`` has no default; every other field falls back to the
    default declared below when the client omits it.
    """

    # Text sent to the model as the generation input.
    prompt: str

    # Text describing what the image should avoid.
    negative_prompt: str = "(deformed, distorted, disfigured), poorly drawn, bad anatomy"

    # Number of generation steps requested.
    steps: int = 4

    # Guidance scale value forwarded with the request.
    cfg_scale: float = 7.0

    # Name of the sampler requested by the client.
    sampler: str = "DPM++ 2M Karras"

    # Generation seed; -1 asks the server to pick a random one.
    seed: int = -1

    # Strength value forwarded with the request.
    strength: float = 0.7
| |
|
def query(prompt: str, negative_prompt: str, steps: int, cfg_scale: float,
          sampler: str, seed: int, strength: float) -> bytes:
    """Post a generation request to ``API_URL`` and return the raw response body.

    Args:
        prompt: Text sent as the model input. Must be non-empty.
        negative_prompt: Added to the payload only when non-empty; its
            truthiness is also sent as ``is_negative``.
        steps: Forwarded as ``steps``.
        cfg_scale: Forwarded as ``cfg_scale``.
        sampler: Accepted for interface compatibility; it is not included in
            the payload sent upstream.
        seed: Forwarded as ``seed``. ``-1`` is replaced by a random integer in
            ``[1, 1000000000]``.
        strength: Forwarded as ``strength``.

    Returns:
        The upstream response body as bytes.

    Raises:
        HTTPException: 400 when ``prompt`` is empty; 504 when the upstream
            request times out; 502 when the request fails for any other
            transport-level reason; otherwise the upstream status code and
            body text when the upstream answers with a non-200 status.
    """
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    payload = {
        "inputs": prompt,
        "is_negative": bool(negative_prompt),
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": seed if seed != -1 else random.randint(1, 1000000000),
        "strength": strength,
    }
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt

    # Transport failures are translated into gateway-style HTTP errors so the
    # client can tell "upstream unreachable/slow" apart from a server bug.
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.exceptions.Timeout as exc:
        raise HTTPException(
            status_code=504, detail="Inference API request timed out"
        ) from exc
    except requests.exceptions.RequestException as exc:
        raise HTTPException(
            status_code=502, detail=f"Inference API request failed: {exc}"
        ) from exc

    if response.status_code != 200:
        # Pass the upstream status and body through unchanged.
        raise HTTPException(status_code=response.status_code, detail=response.text)

    return response.content
| |
|
@app.post("/generate")
async def generate_image(request: ImageRequest):
    """Generate an image for the posted request and return its raw bytes.

    The request fields are forwarded to ``query``; the bytes it returns are
    sent back with media type ``application/octet-stream``.

    Raises:
        HTTPException: Re-raised unchanged when ``query`` raises one; any
            other exception is converted to a 500 whose detail is the
            exception message.
    """
    # Imported here so the handler does not rely on a new module-level import.
    from fastapi.concurrency import run_in_threadpool

    try:
        # query() performs a blocking HTTP call; running it in the thread pool
        # keeps the event loop free to serve other requests meanwhile.
        raw_data = await run_in_threadpool(
            query,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            steps=request.steps,
            cfg_scale=request.cfg_scale,
            sampler=request.sampler,
            seed=request.seed,
            strength=request.strength,
        )
        return Response(content=raw_data, media_type="application/octet-stream")
    except HTTPException:
        # Already carries the intended status code and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e