# My-Image / app.py
# Author: Valtry
# Last commit: 7c9d195 "Update app.py" (verified)
import os
import threading
import time
import uuid
from typing import Optional

from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Point the Hugging Face caches at /app/cache and make sure both the cache and
# the static-output directory exist. This runs before torch/diffusers are
# imported below, presumably so those libraries see the cache settings when
# they initialize.
_CACHE_DIR = "/app/cache"
os.environ.update(HF_HOME=_CACHE_DIR, TRANSFORMERS_CACHE=_CACHE_DIR)
for _writable_dir in (_CACHE_DIR, "/app/static"):
    os.makedirs(_writable_dir, exist_ok=True)
import torch
from diffusers import StableDiffusionPipeline
# -------- CONFIG --------
MODEL_ID = "runwayml/stable-diffusion-v1-5"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
STATIC_FOLDER = "/app/static"
# Public base URL used to build absolute image links. Set the SPACE_URL
# environment variable to override it when not running on this Space; a
# trailing slash is stripped so generated links never contain "//static".
SPACE_URL = os.environ.get("SPACE_URL", "https://valtry-my-image.hf.space").rstrip("/")
# ------------------------
app = FastAPI(title="Valtry Text→Image API")
# Serve generated images publicly at /static/<filename>.
app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")

# flush=True so the progress messages show up promptly in container logs.
print(f"Loading model {MODEL_ID} on {DEVICE}...", flush=True)
# Half precision on GPU only; on CPU the weights are loaded as float32.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
)
pipe = pipe.to(DEVICE)
# The pipeline's safety checker is kept exactly as from_pretrained loaded it.
print("✅ Model loaded", flush=True)
class GenerateReq(BaseModel):
    # JSON request body accepted by POST /generate.

    # Text prompt; the /generate handler rejects empty or whitespace-only
    # values with HTTP 400.
    prompt: str
    # Number of inference steps forwarded to the pipeline. Declared Optional,
    # so an explicit JSON null is accepted and arrives as None.
    num_inference_steps: Optional[int] = 25
    # Forwarded to the pipeline as `guidance_scale`; also accepts null.
    guidance_scale: Optional[float] = 7.5
    # RNG seed; when omitted (None) the handler derives one from the clock.
    seed: Optional[int] = None
# Diffusers pipelines keep per-call state (e.g. the scheduler's timesteps), so
# calls on the shared ``pipe`` from FastAPI's worker threads are serialized.
_PIPE_LOCK = threading.Lock()

# Fallbacks for an explicit JSON null; they mirror GenerateReq's defaults.
_DEFAULT_STEPS = 25
_DEFAULT_GUIDANCE = 7.5
# Upper bound mirroring the home-page form, so one public request cannot hold
# the shared pipeline for an unbounded time.
_MAX_STEPS = 150


@app.post("/generate")
def generate(req: GenerateReq):
    """Generate one image from ``req.prompt`` and save it as a PNG in STATIC_FOLDER.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool: the
    diffusion call is blocking and, inside ``async def``, would stall the event
    loop (and with it /health and static file serving) for the whole run.

    Returns ``{"url", "filename", "seed"}`` on success, where ``seed`` is the
    seed actually used. Otherwise returns a JSONResponse with an ``error`` key:
    400 for invalid input, 500 if the pipeline raises.
    """
    if not req.prompt or not req.prompt.strip():
        return JSONResponse({"error": "prompt is required"}, status_code=400)

    # The Optional fields accept an explicit null; treat it like an omitted
    # field instead of failing later inside int()/float().
    steps = _DEFAULT_STEPS if req.num_inference_steps is None else req.num_inference_steps
    guidance = _DEFAULT_GUIDANCE if req.guidance_scale is None else req.guidance_scale
    if not 1 <= steps <= _MAX_STEPS:
        return JSONResponse(
            {"error": f"num_inference_steps must be between 1 and {_MAX_STEPS}"},
            status_code=400,
        )

    seed = req.seed if req.seed is not None else int(time.time() * 1000) % 2**32
    # Seed the generator on every device; passing None on CPU silently ignored
    # the caller's seed and made results irreproducible.
    try:
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    except (RuntimeError, OverflowError):
        return JSONResponse({"error": "seed is out of range"}, status_code=400)

    try:
        with _PIPE_LOCK:
            result = pipe(
                req.prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance),
                generator=generator,
            )
    except Exception as e:
        return JSONResponse({"error": f"generation failed: {str(e)}"}, status_code=500)

    image = result.images[0]
    filename = f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    image.save(os.path.join(STATIC_FOLDER, filename))
    # NOTE(review): saved images are never deleted, so STATIC_FOLDER grows
    # without bound; consider a retention policy.
    # Return an absolute public URL (so external pages can load it).
    public_url = f"{SPACE_URL}/static/{filename}"
    return {"url": public_url, "filename": filename, "seed": seed}
# Home page. NOTE: `html` is a regular string, NOT an f-string, so Python does
# not try to interpolate the `{...}` braces used by the CSS and JavaScript.
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page UI that POSTs the form values to /generate.

    The Generate button is disabled while a request is in flight, so repeated
    clicks do not send extra /generate requests. The resulting image element
    is built through the DOM rather than by interpolating the URL into
    ``innerHTML``.
    """
    html = """
<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Valtry — Text → Image</title>
<style>
body{font-family:Arial,sans-serif;margin:32px;background:#f7f7f7}
textarea,input,button{font-size:16px;padding:10px;width:100%;margin-top:8px;box-sizing:border-box}
img{max-width:100%;border:1px solid #ccc;padding:6px;background:#fff;margin-top:20px}
</style>
</head>
<body>
<h2>Valtry — Text → Image</h2>
<textarea id="prompt" rows="3" placeholder="A fantasy castle on a cliff at sunset"></textarea><br>
<label>Steps (num_inference_steps)</label>
<input id="steps" type="number" value="25" min="1" max="150"/><br>
<label>Guidance scale</label>
<input id="scale" type="number" value="7.5" step="0.1" min="1" max="20"/><br>
<label>Seed (optional)</label>
<input id="seed" type="number" placeholder="optional seed"/><br>
<button id="generate-btn" onclick="generate()">Generate Image</button>
<div id="status"></div>
<div id="result"></div>
<script>
async function generate(){
  const button = document.getElementById('generate-btn');
  const statusEl = document.getElementById('status');
  const resultEl = document.getElementById('result');
  const prompt = document.getElementById('prompt').value;
  const steps = parseInt(document.getElementById('steps').value || 25);
  const scale = parseFloat(document.getElementById('scale').value || 7.5);
  const seedVal = document.getElementById('seed').value;
  statusEl.textContent = "⏳ Generating — this may take a bit...";
  resultEl.textContent = "";
  const body = { prompt: prompt, num_inference_steps: steps, guidance_scale: scale };
  if (seedVal) body.seed = parseInt(seedVal);
  // One request at a time: every click would otherwise start another generation.
  button.disabled = true;
  try {
    const res = await fetch('/generate', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(body)
    });
    if (!res.ok) {
      const txt = await res.text();
      statusEl.textContent = '❌ Error ' + res.status + ': ' + txt;
      return;
    }
    const data = await res.json();
    statusEl.textContent = '✅ Done — image below';
    // Build the <img> via the DOM so the URL is never parsed as markup.
    const img = document.createElement('img');
    img.src = data.url;
    img.alt = 'generated-image';
    resultEl.appendChild(img);
  } catch (err) {
    statusEl.textContent = '❌ Exception: ' + err.message;
  } finally {
    button.disabled = false;
  }
}
</script>
</body>
</html>
"""
    return HTMLResponse(content=html)
@app.get("/health")
async def health():
    # Constant status payload: reaching this handler at all means the app is
    # up and answering HTTP requests.
    return dict(status="ok")