|
|
import os |
|
|
import time |
|
|
import uuid |
|
|
from typing import Optional |
|
|
from fastapi import FastAPI |
|
|
from fastapi.responses import HTMLResponse, JSONResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
# Point the Hugging Face caches at a writable location inside the container.
# These environment variables MUST be set before `torch`/`diffusers` are
# imported below, otherwise the libraries pick up their default cache paths.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME; both are set here, presumably for
# compatibility with older versions — confirm which transformers version ships.
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
# Ensure the model cache and the static output directory exist up front.
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/app/static", exist_ok=True)
|
|
|
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
|
|
|
|
|
|
# ---- Runtime configuration ------------------------------------------------

# Hugging Face model repository to load at startup.
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Prefer the GPU whenever torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Directory where generated PNGs are written (served via the /static mount).
STATIC_FOLDER = "/app/static"

# Public base URL of the deployed Space, used to build image links.
SPACE_URL = "https://valtry-my-image.hf.space"
|
|
|
|
|
|
|
|
# ---- FastAPI application --------------------------------------------------
app = FastAPI(title="Valtry Text→Image API")

# Serve generated images straight from disk at /static/<filename>.
app.mount("/static", StaticFiles(directory=STATIC_FOLDER), name="static")

# ---- Model loading (runs once at import time; may download weights) -------
print(f"Loading model {MODEL_ID} on {DEVICE}...")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    # fp16 halves GPU memory use; CPU inference requires fp32.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
)
pipe = pipe.to(DEVICE)
# NOTE(review): this line is a no-op — it reads the existing `safety_checker`
# attribute (or None when absent) and assigns it straight back.  If the
# intent was to disable the NSFW safety checker, the line should be
# `pipe.safety_checker = None`; confirm the intent before changing behavior.
pipe.safety_checker = getattr(pipe, "safety_checker", None)
print("✅ Model loaded")
|
|
|
|
|
class GenerateReq(BaseModel):
    """Request body for POST /generate."""

    # Text prompt describing the image to generate (required).
    prompt: str
    # Number of diffusion denoising steps; more steps = slower, finer detail.
    num_inference_steps: Optional[int] = 25
    # Classifier-free guidance scale; higher values follow the prompt more
    # strictly at the cost of image diversity.
    guidance_scale: Optional[float] = 7.5
    # Optional RNG seed for reproducible output; a time-based seed is derived
    # by the endpoint when omitted.
    seed: Optional[int] = None
|
|
|
|
|
@app.post("/generate")
async def generate(req: GenerateReq):
    """Generate an image from a text prompt and save it under /static.

    Returns a JSON payload with the public URL, the filename of the saved
    PNG, and the effective seed; or a 400/500 JSON error response.
    """
    # Reject empty / whitespace-only prompts early.
    if not req.prompt or not req.prompt.strip():
        return JSONResponse({"error": "prompt is required"}, status_code=400)

    # Derive a time-based 32-bit seed when the client did not supply one,
    # so every generation is reproducible given the returned seed.
    seed = req.seed if req.seed is not None else int(time.time() * 1000) % 2**32
    # BUG FIX: the original only seeded on CUDA (passing generator=None on
    # CPU), silently ignoring `req.seed`.  torch.Generator works on both
    # devices, so seed unconditionally.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    try:
        result = pipe(
            req.prompt,
            num_inference_steps=int(req.num_inference_steps),
            guidance_scale=float(req.guidance_scale),
            generator=generator,
        )
    except Exception as e:
        # Surface pipeline failures (OOM, invalid params, ...) as a 500
        # JSON error instead of letting the server return a bare traceback.
        return JSONResponse({"error": f"generation failed: {str(e)}"}, status_code=500)

    image = result.images[0]

    # Collision-resistant filename: unix timestamp + short random hex suffix.
    filename = f"img_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
    file_path = os.path.join(STATIC_FOLDER, filename)
    image.save(file_path)

    # BUG FIX: the original interpolated the literal text "(unknown)" into
    # the URL instead of the generated filename, so every returned link
    # pointed at a nonexistent file.
    public_url = f"{SPACE_URL}/static/{filename}"
    return {"url": public_url, "filename": filename, "seed": seed}
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page UI: a prompt form that POSTs to /generate
    via fetch() and displays the resulting image inline."""
    html = """
<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Valtry — Text → Image</title>
<style>
body{font-family:Arial,sans-serif;margin:32px;background:#f7f7f7}
textarea,input,button{font-size:16px;padding:10px;width:100%;margin-top:8px;box-sizing:border-box}
img{max-width:100%;border:1px solid #ccc;padding:6px;background:#fff;margin-top:20px}
</style>
</head>
<body>
<h2>Valtry — Text → Image</h2>
<textarea id="prompt" rows="3" placeholder="A fantasy castle on a cliff at sunset"></textarea><br>
<label>Steps (num_inference_steps)</label>
<input id="steps" type="number" value="25" min="1" max="150"/><br>
<label>Guidance scale</label>
<input id="scale" type="number" value="7.5" step="0.1" min="1" max="20"/><br>
<label>Seed (optional)</label>
<input id="seed" type="number" placeholder="optional seed"/><br>
<button onclick="generate()">Generate Image</button>
<div id="status"></div>
<div id="result"></div>

<script>
async function generate(){
const prompt = document.getElementById('prompt').value;
const steps = parseInt(document.getElementById('steps').value || 25);
const scale = parseFloat(document.getElementById('scale').value || 7.5);
const seedVal = document.getElementById('seed').value;

document.getElementById('status').textContent = "⏳ Generating — this may take a bit...";
document.getElementById('result').innerHTML = "";

const body = { prompt: prompt, num_inference_steps: steps, guidance_scale: scale };
if (seedVal) body.seed = parseInt(seedVal);

try {
const res = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body)
});

if (!res.ok) {
const txt = await res.text();
document.getElementById('status').textContent = '❌ Error ' + res.status + ': ' + txt;
return;
}

const data = await res.json();
document.getElementById('status').textContent = '✅ Done — image below';
document.getElementById('result').innerHTML = `<img src="${data.url}" alt="generated-image"/>`;
} catch (err) {
document.getElementById('status').textContent = '❌ Exception: ' + err.message;
}
}
</script>
</body>
</html>
"""
    return HTMLResponse(content=html)
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload