tejani commited on
Commit
da72ab7
·
verified ·
1 Parent(s): 9bb6ee8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import requests
4
+ import io
5
+ import random
6
+ import os
7
+ from PIL import Image
8
+ import base64
9
+
10
+ app = FastAPI()
11
+
12
+ API_URL = f"https://api-inference.huggingface.co/models/{os.getenv("HF_MODEL")}"
13
+ headers = {"Authorization": f"Bearer {os.getenv("HF_TOKEN")}"}
14
+ timeout = 100
15
+
16
+ class ImageRequest(BaseModel):
17
+ prompt: str
18
+ negative_prompt: str = "(deformed, distorted, disfigured), poorly drawn, bad anatomy"
19
+ steps: int = 4
20
+ cfg_scale: float = 7.0
21
+ sampler: str = "DPM++ SDE Karras"
22
+ seed: int = -1
23
+ strength: float = 0.7
24
+
25
+ def query(prompt: str, negative_prompt: str, steps: int, cfg_scale: float,
26
+ sampler: str, seed: int, strength: float):
27
+ if not prompt:
28
+ raise HTTPException(status_code=400, detail="Prompt is required")
29
+
30
+ payload = {
31
+ "inputs": prompt,
32
+ "is_negative": bool(negative_prompt),
33
+ "steps": steps,
34
+ "cfg_scale": cfg_scale,
35
+ "seed": seed if seed != -1 else random.randint(1, 1000000000),
36
+ "strength": strength
37
+ }
38
+
39
+ if negative_prompt:
40
+ payload["negative_prompt"] = negative_prompt
41
+
42
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
43
+
44
+ if response.status_code != 200:
45
+ raise HTTPException(status_code=response.status_code, detail=response.text)
46
+
47
+ image_bytes = response.content
48
+ image = Image.open(io.BytesIO(image_bytes))
49
+
50
+ buffered = io.BytesIO()
51
+ image.save(buffered, format="PNG")
52
+ img_str = base64.b64encode(buffered.getvalue()).decode()
53
+
54
+ return {"image": f"data:image/png;base64,{img_str}"}
55
+
56
+ @app.post("/generate")
57
+ async def generate_image(request: ImageRequest):
58
+ try:
59
+ result = query(
60
+ prompt=request.prompt,
61
+ negative_prompt=request.negative_prompt,
62
+ steps=request.steps,
63
+ cfg_scale=request.cfg_scale,
64
+ sampler=request.sampler,
65
+ seed=request.seed,
66
+ strength=request.strength
67
+ )
68
+ return result
69
+ except HTTPException as e:
70
+ raise e
71
+ except Exception as e:
72
+ raise HTTPException(status_code=500, detail=str(e))