Valtry commited on
Commit
67c9c6c
·
verified ·
1 Parent(s): 6709abd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -57
app.py CHANGED
@@ -1,45 +1,41 @@
1
  import os
2
  import time
3
  import uuid
4
- from io import BytesIO
5
  from typing import Optional
6
-
7
- from fastapi import FastAPI, Request
8
  from fastapi.responses import HTMLResponse, JSONResponse
9
  from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
11
 
12
- # Put the HF / transformers cache into a writable folder
13
  os.environ["HF_HOME"] = "/app/cache"
14
  os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
15
  os.makedirs("/app/cache", exist_ok=True)
16
  os.makedirs("/app/static", exist_ok=True)
17
 
18
- # Import after setting env
19
  import torch
20
  from diffusers import StableDiffusionPipeline
21
 
22
  # -------- CONFIG --------
23
- MODEL_ID = "runwayml/stable-diffusion-v1-5" # change if you prefer another model
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  STATIC_FOLDER = "/app/static"
 
26
  # ------------------------
27
 
28
  app = FastAPI(title="Valtry Text→Image API")
29
 
30
- # Mount static folder so generated images are publicly accessible at /static/...
31
  app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")
32
 
33
- # Load model once at startup
34
- print("Loading model", MODEL_ID, "to device", DEVICE)
35
- # If your model is gated, set use_auth_token=os.getenv("HF_TOKEN") in from_pretrained(...)
36
  pipe = StableDiffusionPipeline.from_pretrained(
37
  MODEL_ID,
38
  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
39
  )
40
  pipe = pipe.to(DEVICE)
41
- pipe.safety_checker = pipe.safety_checker if hasattr(pipe, "safety_checker") else None
42
- print("Model loaded")
43
 
44
  class GenerateReq(BaseModel):
45
  prompt: str
@@ -49,13 +45,12 @@ class GenerateReq(BaseModel):
49
 
50
  @app.post("/generate")
51
  async def generate(req: GenerateReq):
52
- if not req.prompt or not req.prompt.strip():
53
  return JSONResponse({"error": "prompt is required"}, status_code=400)
54
 
55
  seed = req.seed if req.seed is not None else int(time.time() * 1000) % 2**32
56
  generator = torch.Generator(device=DEVICE).manual_seed(seed) if DEVICE == "cuda" else None
57
 
58
- # Generate image
59
  try:
60
  result = pipe(
61
  req.prompt,
@@ -64,89 +59,77 @@ async def generate(req: GenerateReq):
64
  generator=generator,
65
  )
66
  except Exception as e:
67
- return JSONResponse({"error": f"model generation failed: {str(e)}"}, status_code=500)
68
 
69
  image = result.images[0]
70
 
71
- # Save image file
72
  filename = f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
73
  file_path = os.path.join(STATIC_FOLDER, filename)
74
  image.save(file_path)
75
 
76
- # Return the public URL (relative)
77
- return {"url": f"/static/{filename}", "filename": filename}
 
78
 
79
- # Simple home page for quick testing in browser
80
  @app.get("/", response_class=HTMLResponse)
81
  async def home():
82
- return """
83
  <!doctype html>
84
  <html>
85
  <head>
86
  <meta charset="utf-8"/>
87
- <title>Valtry Text→Image</title>
88
  <style>
89
- body{font-family:Arial,sans-serif;margin:32px;background:#f7f7f7}
90
- input, button, textarea{font-size:16px;padding:10px;width:100%;box-sizing:border-box;margin-top:8px}
91
- #result{margin-top:20px}
92
- img{max-width:100%;border:1px solid #ccc;padding:6px;background:#fff}
93
- label{font-weight:600}
94
  </style>
95
  </head>
96
  <body>
97
  <h2>Valtry — Text → Image</h2>
98
- <label>Prompt</label>
99
- <textarea id="prompt" rows="3" placeholder="A fantasy castle on a cliff at sunset"></textarea>
100
- <label>Max steps (num_inference_steps)</label>
101
  <input id="steps" type="number" value="25" min="1" max="150"/>
102
- <label>Guidance scale</label>
103
  <input id="scale" type="number" value="7.5" step="0.1" min="1" max="20"/>
104
- <label>Seed (optional)</label>
105
- <input id="seed" type="number" placeholder="leave empty for random"/>
106
  <button onclick="generate()">Generate Image</button>
107
  <div id="status"></div>
108
  <div id="result"></div>
109
 
110
  <script>
111
- async function generate(){
112
  const prompt = document.getElementById('prompt').value;
113
- const steps = parseInt(document.getElementById('steps').value || 25);
114
- const scale = parseFloat(document.getElementById('scale').value || 7.5);
115
  const seedVal = document.getElementById('seed').value;
116
- document.getElementById('status').textContent = "⏳ Generating — this may take a bit...";
 
117
  document.getElementById('result').innerHTML = "";
118
 
119
- const body = {
120
- prompt: prompt,
121
- num_inference_steps: steps,
122
- guidance_scale: scale
123
- };
124
  if (seedVal) body.seed = parseInt(seedVal);
125
 
126
- try {
127
- const res = await fetch('/generate', {
128
  method: 'POST',
129
- headers: {'Content-Type': 'application/json'},
130
  body: JSON.stringify(body)
131
- });
132
- if (!res.ok) {
133
- const txt = await res.text();
134
- document.getElementById('status').textContent = '❌ Error ' + res.status + ': ' + txt;
135
- return;
136
- }
137
  const data = await res.json();
138
- document.getElementById('status').textContent = '✅ Done — image below';
139
- document.getElementById('result').innerHTML = `<img src="${data.url}" alt="generated-image"/>`;
140
- } catch (err) {
141
- document.getElementById('status').textContent = '❌ Exception: ' + err.message;
142
- }
143
- }
 
 
 
 
144
  </script>
145
  </body>
146
  </html>
147
  """
148
 
149
- # optional health route
150
  @app.get("/health")
151
  async def health():
152
  return {"status": "ok"}
 
1
  import os
2
  import time
3
  import uuid
 
4
  from typing import Optional
5
+ from fastapi import FastAPI
 
6
  from fastapi.responses import HTMLResponse, JSONResponse
7
  from fastapi.staticfiles import StaticFiles
8
  from pydantic import BaseModel
9
 
10
+ # Set cache directories to writable paths
11
  os.environ["HF_HOME"] = "/app/cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
13
  os.makedirs("/app/cache", exist_ok=True)
14
  os.makedirs("/app/static", exist_ok=True)
15
 
 
16
  import torch
17
  from diffusers import StableDiffusionPipeline
18
 
19
  # -------- CONFIG --------
20
+ MODEL_ID = "runwayml/stable-diffusion-v1-5"
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
  STATIC_FOLDER = "/app/static"
23
+ SPACE_URL = "https://valtry-my-image.hf.space" # <-- CHANGE THIS to your Space's URL
24
  # ------------------------
25
 
26
  app = FastAPI(title="Valtry Text→Image API")
27
 
28
+ # Serve static folder
29
  app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")
30
 
31
+ print(f"Loading model {MODEL_ID} on {DEVICE}...")
 
 
32
  pipe = StableDiffusionPipeline.from_pretrained(
33
  MODEL_ID,
34
  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
35
  )
36
  pipe = pipe.to(DEVICE)
37
+ pipe.safety_checker = getattr(pipe, "safety_checker", None)
38
+ print("Model loaded")
39
 
40
  class GenerateReq(BaseModel):
41
  prompt: str
 
45
 
46
  @app.post("/generate")
47
  async def generate(req: GenerateReq):
48
+ if not req.prompt.strip():
49
  return JSONResponse({"error": "prompt is required"}, status_code=400)
50
 
51
  seed = req.seed if req.seed is not None else int(time.time() * 1000) % 2**32
52
  generator = torch.Generator(device=DEVICE).manual_seed(seed) if DEVICE == "cuda" else None
53
 
 
54
  try:
55
  result = pipe(
56
  req.prompt,
 
59
  generator=generator,
60
  )
61
  except Exception as e:
62
+ return JSONResponse({"error": f"generation failed: {str(e)}"}, status_code=500)
63
 
64
  image = result.images[0]
65
 
 
66
  filename = f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
67
  file_path = os.path.join(STATIC_FOLDER, filename)
68
  image.save(file_path)
69
 
70
+ # Full public URL so it loads in browser
71
+ public_url = f"{SPACE_URL}/static/{filename}"
72
+ return {"url": public_url, "filename": filename}
73
 
 
74
  @app.get("/", response_class=HTMLResponse)
75
  async def home():
76
+ return f"""
77
  <!doctype html>
78
  <html>
79
  <head>
80
  <meta charset="utf-8"/>
81
+ <title>Valtry Text→Image</title>
82
  <style>
83
+ body{{font-family:Arial,sans-serif;margin:32px;background:#f7f7f7}}
84
+ input, button, textarea{{font-size:16px;padding:10px;width:100%;margin-top:8px}}
85
+ img{{max-width:100%;border:1px solid #ccc;padding:6px;background:#fff;margin-top:20px}}
 
 
86
  </style>
87
  </head>
88
  <body>
89
  <h2>Valtry — Text → Image</h2>
90
+ <textarea id="prompt" rows="3" placeholder="A fantasy castle on a cliff at sunset"></textarea><br>
 
 
91
  <input id="steps" type="number" value="25" min="1" max="150"/>
 
92
  <input id="scale" type="number" value="7.5" step="0.1" min="1" max="20"/>
93
+ <input id="seed" type="number" placeholder="optional seed"/>
 
94
  <button onclick="generate()">Generate Image</button>
95
  <div id="status"></div>
96
  <div id="result"></div>
97
 
98
  <script>
99
+ async function generate(){{
100
  const prompt = document.getElementById('prompt').value;
101
+ const steps = parseInt(document.getElementById('steps').value);
102
+ const scale = parseFloat(document.getElementById('scale').value);
103
  const seedVal = document.getElementById('seed').value;
104
+
105
+ document.getElementById('status').textContent = "⏳ Generating...";
106
  document.getElementById('result').innerHTML = "";
107
 
108
+ const body = {{ prompt, num_inference_steps: steps, guidance_scale: scale }};
 
 
 
 
109
  if (seedVal) body.seed = parseInt(seedVal);
110
 
111
+ try {{
112
+ const res = await fetch('/generate', {{
113
  method: 'POST',
114
+ headers: {{ 'Content-Type': 'application/json' }},
115
  body: JSON.stringify(body)
116
+ }});
 
 
 
 
 
117
  const data = await res.json();
118
+ if (res.ok) {{
119
+ document.getElementById('status').textContent = " Done";
120
+ document.getElementById('result').innerHTML = `<img src="${data.url}" alt="Generated image"/>`;
121
+ }} else {{
122
+ document.getElementById('status').textContent = "❌ Error: " + data.error;
123
+ }}
124
+ }} catch (err) {{
125
+ document.getElementById('status').textContent = "❌ " + err.message;
126
+ }}
127
+ }}
128
  </script>
129
  </body>
130
  </html>
131
  """
132
 
 
133
  @app.get("/health")
134
  async def health():
135
  return {"status": "ok"}