File size: 5,285 Bytes
2f0bcde
 
 
 
67c9c6c
2f0bcde
 
ca16976
2f0bcde
7c9d195
2f0bcde
 
 
 
 
ca16976
2f0bcde
 
 
67c9c6c
2f0bcde
 
7c9d195
2f0bcde
 
 
ca16976
7c9d195
2f0bcde
ca16976
67c9c6c
2f0bcde
 
 
 
 
67c9c6c
 
4b1ddd1
2f0bcde
ca16976
2f0bcde
 
 
ca16976
 
2f0bcde
7c9d195
2f0bcde
4b1ddd1
2f0bcde
 
ca16976
2f0bcde
 
 
 
 
 
 
 
67c9c6c
2f0bcde
 
 
 
 
 
 
7c9d195
67c9c6c
 
2f0bcde
7c9d195
4b1ddd1
 
7c9d195
2f0bcde
4b1ddd1
 
2f0bcde
7c9d195
2f0bcde
7c9d195
 
 
2f0bcde
4b1ddd1
2f0bcde
 
67c9c6c
7c9d195
 
 
 
 
 
2f0bcde
 
 
 
 
7c9d195
2f0bcde
7c9d195
 
2f0bcde
67c9c6c
7c9d195
2f0bcde
 
7c9d195
2f0bcde
 
7c9d195
 
2f0bcde
7c9d195
2f0bcde
7c9d195
 
 
 
 
 
 
 
2f0bcde
7c9d195
 
 
 
 
 
2f0bcde
4b1ddd1
 
2f0bcde
7c9d195
2f0bcde
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import time
import uuid
from typing import Optional
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Point the Hugging Face / transformers caches at a directory we create
# ourselves, and make sure the static output directory exists too.
# This runs before torch/diffusers are imported further down the file.
os.environ.update(HF_HOME="/app/cache", TRANSFORMERS_CACHE="/app/cache")
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/app/static", exist_ok=True)

import torch
from diffusers import StableDiffusionPipeline

# -------- CONFIG --------
# Model identifier handed to StableDiffusionPipeline.from_pretrained.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# Run on the GPU when torch reports CUDA as available, otherwise on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Directory that generated images are written to and served from.
STATIC_FOLDER = "/app/static"
# Public base URL used to build absolute image links.
# Override per deployment with the SPACE_URL environment variable; an unset or
# empty variable falls back to the original default. A trailing slash is
# stripped so that f"{SPACE_URL}/static/..." never contains a double slash.
SPACE_URL = (os.environ.get("SPACE_URL") or "https://valtry-my-image.hf.space").rstrip("/")
# ------------------------

# FastAPI application instance that the route handlers below register on.
app = FastAPI(title="Valtry Text→Image API")

# Expose everything under STATIC_FOLDER at the /static/... URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")

# The pipeline is loaded once, at import time, and shared by all requests.
print(f"Loading model {MODEL_ID} on {DEVICE}...")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    # float32 everywhere except CUDA, where half precision is used.
    torch_dtype=torch.float32 if DEVICE != "cuda" else torch.float16,
).to(DEVICE)
# Re-assigns whatever safety checker the pipeline already carries
# (None when the attribute is missing); it does not switch the checker off.
pipe.safety_checker = getattr(pipe, "safety_checker", None)
print("✅ Model loaded")

class GenerateReq(BaseModel):
    """JSON request body accepted by ``POST /generate``."""

    # Text prompt to render; the handler answers 400 when it is empty or
    # whitespace-only.
    prompt: str
    # Forwarded to the pipeline as ``num_inference_steps`` (default 25).
    num_inference_steps: Optional[int] = 25
    # Forwarded to the pipeline as ``guidance_scale`` (default 7.5).
    guidance_scale: Optional[float] = 7.5
    # Optional RNG seed; when omitted the handler derives one from the clock.
    seed: Optional[int] = None

@app.post("/generate")
async def generate(req: GenerateReq):
    """Render one image for ``req.prompt`` and return its public URL.

    Args:
        req: Parsed request body. ``num_inference_steps`` and
            ``guidance_scale`` may be ``None`` (explicit JSON ``null``), in
            which case the model defaults of 25 and 7.5 are used. ``seed``
            may be ``None``, in which case a seed is derived from the
            current time in milliseconds.

    Returns:
        On success, a dict with ``url`` (absolute link built from
        ``SPACE_URL``) and ``filename`` (name of the PNG inside
        ``STATIC_FOLDER``). On failure, a ``JSONResponse`` carrying an
        ``error`` message: status 400 for an empty prompt or a step count
        below 1, status 500 when generation or saving the file fails.
    """
    if not req.prompt or not req.prompt.strip():
        return JSONResponse({"error": "prompt is required"}, status_code=400)

    # The request model allows explicit nulls; fall back to the declared
    # defaults instead of crashing on int(None) / float(None).
    steps = 25 if req.num_inference_steps is None else int(req.num_inference_steps)
    scale = 7.5 if req.guidance_scale is None else float(req.guidance_scale)
    if steps < 1:
        return JSONResponse(
            {"error": "num_inference_steps must be >= 1"}, status_code=400
        )

    seed = req.seed if req.seed is not None else int(time.time() * 1000) % 2**32

    try:
        # Seed a generator on whichever device is in use so that a supplied
        # seed is honoured on CPU as well as on CUDA. Built inside the try
        # block because manual_seed raises for out-of-range values.
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is occupied for the duration of the generation.
        result = pipe(
            req.prompt,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )
    except Exception as e:
        return JSONResponse({"error": f"generation failed: {str(e)}"}, status_code=500)

    image = result.images[0]

    # Timestamp plus a short random suffix keeps filenames unique per request.
    filename = f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    file_path = os.path.join(STATIC_FOLDER, filename)
    try:
        image.save(file_path)
    except OSError as e:
        return JSONResponse({"error": f"could not save image: {e}"}, status_code=500)

    # Return an absolute public URL (so external pages can load it)
    public_url = f"{SPACE_URL}/static/{filename}"
    return {"url": public_url, "filename": filename}

# Home page: NOTE -> regular string (NOT an f-string) to avoid Python interpolating JS {..}
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page UI that posts the form values to /generate.

    The markup is a plain (non-f) string so the braces in the embedded CSS
    and JavaScript reach the browser untouched.
    """
    page = """
    <!doctype html>
    <html>
    <head>
      <meta charset="utf-8"/>
      <title>Valtry — Text → Image</title>
      <style>
        body{font-family:Arial,sans-serif;margin:32px;background:#f7f7f7}
        textarea,input,button{font-size:16px;padding:10px;width:100%;margin-top:8px;box-sizing:border-box}
        img{max-width:100%;border:1px solid #ccc;padding:6px;background:#fff;margin-top:20px}
      </style>
    </head>
    <body>
      <h2>Valtry — Text → Image</h2>
      <textarea id="prompt" rows="3" placeholder="A fantasy castle on a cliff at sunset"></textarea><br>
      <label>Steps (num_inference_steps)</label>
      <input id="steps" type="number" value="25" min="1" max="150"/><br>
      <label>Guidance scale</label>
      <input id="scale" type="number" value="7.5" step="0.1" min="1" max="20"/><br>
      <label>Seed (optional)</label>
      <input id="seed" type="number" placeholder="optional seed"/><br>
      <button onclick="generate()">Generate Image</button>
      <div id="status"></div>
      <div id="result"></div>

      <script>
        async function generate(){
          const prompt = document.getElementById('prompt').value;
          const steps = parseInt(document.getElementById('steps').value || 25);
          const scale = parseFloat(document.getElementById('scale').value || 7.5);
          const seedVal = document.getElementById('seed').value;

          document.getElementById('status').textContent = "⏳ Generating — this may take a bit...";
          document.getElementById('result').innerHTML = "";

          const body = { prompt: prompt, num_inference_steps: steps, guidance_scale: scale };
          if (seedVal) body.seed = parseInt(seedVal);

          try {
            const res = await fetch('/generate', {
              method: 'POST',
              headers: { 'Content-Type': 'application/json' },
              body: JSON.stringify(body)
            });

            if (!res.ok) {
              const txt = await res.text();
              document.getElementById('status').textContent = '❌ Error ' + res.status + ': ' + txt;
              return;
            }

            const data = await res.json();
            document.getElementById('status').textContent = '✅ Done — image below';
            document.getElementById('result').innerHTML = `<img src="${data.url}" alt="generated-image"/>`;
          } catch (err) {
            document.getElementById('status').textContent = '❌ Exception: ' + err.message;
          }
        }
      </script>
    </body>
    </html>
    """
    return HTMLResponse(page)

@app.get("/health")
async def health():
    """Liveness probe: always reports a fixed ``ok`` status."""
    payload = {"status": "ok"}
    return payload