Valtry commited on
Commit
2f0bcde
·
verified ·
1 Parent(s): 4b1ddd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -44
app.py CHANGED
@@ -1,64 +1,152 @@
 
 
 
 
 
 
1
  from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse
 
3
  from pydantic import BaseModel
4
- from diffusers import StableDiffusionPipeline
 
 
 
 
 
 
 
5
  import torch
6
- import base64
7
- from io import BytesIO
8
- from PIL import Image
 
 
 
 
 
 
9
 
10
- app = FastAPI()
 
11
 
12
- # Load model
13
- model_id = "runwayml/stable-diffusion-v1-5"
14
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cpu") # use "cuda" if GPU
 
 
 
 
 
 
 
15
 
16
- # Request body schema
17
- class GenerateRequest(BaseModel):
18
  prompt: str
 
 
 
19
 
20
- # API route for text-to-image
21
  @app.post("/generate")
22
- def generate_image(request: GenerateRequest):
23
- image = pipe(request.prompt).images[0]
 
24
 
25
- # Save to base64
26
- buffer = BytesIO()
27
- image.save(buffer, format="PNG")
28
- img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
29
- return {"image_base64": img_str}
30
 
31
- # Simple HTML test page
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  @app.get("/", response_class=HTMLResponse)
33
  async def home():
34
  return """
 
35
  <html>
36
  <head>
37
- <title>Valtry Text-to-Image</title>
 
 
 
 
 
 
 
 
38
  </head>
39
- <body style="font-family: Arial; text-align: center; padding: 50px;">
40
- <h2>🖌️ Valtry Text-to-Image Generator</h2>
41
- <input id="prompt" type="text" placeholder="Enter your prompt..." style="width: 300px; padding: 5px;" />
42
- <button onclick="generate()">Generate</button>
43
- <div id="result" style="margin-top: 20px;"></div>
44
- <script>
45
- async function generate() {
46
- const prompt = document.getElementById("prompt").value;
47
- document.getElementById("result").innerHTML = "⏳ Generating...";
48
- const res = await fetch('/generate', {
49
- method: 'POST',
50
- headers: { 'Content-Type': 'application/json' },
51
- body: JSON.stringify({ prompt })
52
- });
53
- if (!res.ok) {
54
- document.getElementById("result").innerHTML = "❌ Error: " + res.status;
55
- return;
56
- }
57
- const data = await res.json();
58
- document.getElementById("result").innerHTML =
59
- `<img src="data:image/png;base64,${data.image_base64}" style="max-width:100%;"/>`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  }
61
- </script>
 
 
 
 
 
 
 
62
  </body>
63
  </html>
64
- """
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import uuid
4
+ from io import BytesIO
5
+ from typing import Optional
6
+
7
  from fastapi import FastAPI, Request
8
+ from fastapi.responses import HTMLResponse, JSONResponse
9
+ from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
11
+
12
+ # Put the HF / transformers cache into a writable folder
13
+ os.environ["HF_HOME"] = "/app/cache"
14
+ os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
15
+ os.makedirs("/app/cache", exist_ok=True)
16
+ os.makedirs("/app/static", exist_ok=True)
17
+
18
+ # Import after setting env
19
  import torch
20
+ from diffusers import StableDiffusionPipeline
21
+
22
+ # -------- CONFIG --------
23
+ MODEL_ID = "runwayml/stable-diffusion-v1-5" # change if you prefer another model
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+ STATIC_FOLDER = "/app/static"
26
+ # ------------------------
27
+
28
+ app = FastAPI(title="Valtry Text→Image API")
29
 
30
+ # Mount static folder so generated images are publicly accessible at /static/...
31
+ app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")
32
 
33
+ # Load model once at startup
34
+ print("Loading model", MODEL_ID, "to device", DEVICE)
35
+ # If your model is gated, set use_auth_token=os.getenv("HF_TOKEN") in from_pretrained(...)
36
+ pipe = StableDiffusionPipeline.from_pretrained(
37
+ MODEL_ID,
38
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
39
+ )
40
+ pipe = pipe.to(DEVICE)
41
+ pipe.safety_checker = pipe.safety_checker if hasattr(pipe, "safety_checker") else None
42
+ print("Model loaded")
43
 
44
+ class GenerateReq(BaseModel):
 
45
  prompt: str
46
+ num_inference_steps: Optional[int] = 25
47
+ guidance_scale: Optional[float] = 7.5
48
+ seed: Optional[int] = None
49
 
 
50
  @app.post("/generate")
51
+ async def generate(req: GenerateReq):
52
+ if not req.prompt or not req.prompt.strip():
53
+ return JSONResponse({"error": "prompt is required"}, status_code=400)
54
 
55
+ seed = req.seed if req.seed is not None else int(time.time() * 1000) % 2**32
56
+ generator = torch.Generator(device=DEVICE).manual_seed(seed) if DEVICE == "cuda" else None
 
 
 
57
 
58
+ # Generate image
59
+ try:
60
+ result = pipe(
61
+ req.prompt,
62
+ num_inference_steps=int(req.num_inference_steps),
63
+ guidance_scale=float(req.guidance_scale),
64
+ generator=generator,
65
+ )
66
+ except Exception as e:
67
+ return JSONResponse({"error": f"model generation failed: {str(e)}"}, status_code=500)
68
+
69
+ image = result.images[0]
70
+
71
+ # Save image file
72
+ filename = f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
73
+ file_path = os.path.join(STATIC_FOLDER, filename)
74
+ image.save(file_path)
75
+
76
+ # Return the public URL (relative)
77
+ return {"url": f"/static/{filename}", "filename": filename}
78
+
79
+ # Simple home page for quick testing in browser
80
  @app.get("/", response_class=HTMLResponse)
81
  async def home():
82
  return """
83
+ <!doctype html>
84
  <html>
85
  <head>
86
+ <meta charset="utf-8"/>
87
+ <title>Valtry Text→Image</title>
88
+ <style>
89
+ body{font-family:Arial,sans-serif;margin:32px;background:#f7f7f7}
90
+ input, button, textarea{font-size:16px;padding:10px;width:100%;box-sizing:border-box;margin-top:8px}
91
+ #result{margin-top:20px}
92
+ img{max-width:100%;border:1px solid #ccc;padding:6px;background:#fff}
93
+ label{font-weight:600}
94
+ </style>
95
  </head>
96
+ <body>
97
+ <h2>Valtry TextImage</h2>
98
+ <label>Prompt</label>
99
+ <textarea id="prompt" rows="3" placeholder="A fantasy castle on a cliff at sunset"></textarea>
100
+ <label>Max steps (num_inference_steps)</label>
101
+ <input id="steps" type="number" value="25" min="1" max="150"/>
102
+ <label>Guidance scale</label>
103
+ <input id="scale" type="number" value="7.5" step="0.1" min="1" max="20"/>
104
+ <label>Seed (optional)</label>
105
+ <input id="seed" type="number" placeholder="leave empty for random"/>
106
+ <button onclick="generate()">Generate Image</button>
107
+ <div id="status"></div>
108
+ <div id="result"></div>
109
+
110
+ <script>
111
+ async function generate(){
112
+ const prompt = document.getElementById('prompt').value;
113
+ const steps = parseInt(document.getElementById('steps').value || 25);
114
+ const scale = parseFloat(document.getElementById('scale').value || 7.5);
115
+ const seedVal = document.getElementById('seed').value;
116
+ document.getElementById('status').textContent = " Generating — this may take a bit...";
117
+ document.getElementById('result').innerHTML = "";
118
+
119
+ const body = {
120
+ prompt: prompt,
121
+ num_inference_steps: steps,
122
+ guidance_scale: scale
123
+ };
124
+ if (seedVal) body.seed = parseInt(seedVal);
125
+
126
+ try {
127
+ const res = await fetch('/generate', {
128
+ method: 'POST',
129
+ headers: {'Content-Type': 'application/json'},
130
+ body: JSON.stringify(body)
131
+ });
132
+ if (!res.ok) {
133
+ const txt = await res.text();
134
+ document.getElementById('status').textContent = '❌ Error ' + res.status + ': ' + txt;
135
+ return;
136
  }
137
+ const data = await res.json();
138
+ document.getElementById('status').textContent = '✅ Done — image below';
139
+ document.getElementById('result').innerHTML = `<img src="${data.url}" alt="generated-image"/>`;
140
+ } catch (err) {
141
+ document.getElementById('status').textContent = '❌ Exception: ' + err.message;
142
+ }
143
+ }
144
+ </script>
145
  </body>
146
  </html>
147
+ """
148
+
149
+ # optional health route
150
+ @app.get("/health")
151
+ async def health():
152
+ return {"status": "ok"}