|
|
import os
import re
import threading
import time
import uuid

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, render_template, request
from PIL import Image
|
|
|
|
|
|
# Flask application object; the `@app.route` handlers in this file register on it.
app = Flask(__name__)
|
|
|
|
|
|
|
|
|
# Checkpoint passed to `StableDiffusionPipeline.from_pretrained`; presumably a
# Hugging Face Hub repo id ("user/model") — NOTE(review): confirm.
MODEL_PATH = "Roshan1162003/fine_tuned_model"

# Directory that generated PNGs are written to. It is relative to the current
# working directory, so the server is presumably started from the project
# root where Flask's `static/` folder lives — NOTE(review): confirm.
STATIC_IMAGES_PATH = os.path.join("static", "images")
os.makedirs(STATIC_IMAGES_PATH, exist_ok=True)

# Lower-case terms that the prompt filter rejects.
RESTRICTED_TERMS = (
    "crime abuse violence illegal explicit nsfw "
    "offensive hate nude porn gore drug"
).split()
|
|
|
|
|
|
|
|
|
# Load the pipeline once at import time; every request handler shares this
# single instance. Half precision is used only on CUDA: fp16 inference is
# poorly supported on CPU, so the CPU path stays in fp32.
_USE_CUDA = torch.cuda.is_available()

pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16 if _USE_CUDA else torch.float32,
    use_safetensors=True,
    # NOTE(review): recent diffusers releases document this argument as
    # `token`; confirm against the pinned version before renaming it.
    use_auth_token=os.getenv("HF_TOKEN"),
)
if _USE_CUDA:
    pipe = pipe.to("cuda")

print("Model loaded successfully")
|
|
|
|
|
|
|
|
|
# Output size (width, height) in pixels for each selectable ratio label.
# Width is fixed at 512 and height follows the "width:height" ratio, giving
# 512x512, 512x384 and 512x288. All are multiples of 8, which Stable Diffusion
# pipelines require for width/height.
ASPECT_RATIOS = {
    label: (512, 512 * ratio_h // ratio_w)
    for label, (ratio_w, ratio_h) in (
        ("1:1", (1, 1)),
        ("4:3", (4, 3)),
        ("16:9", (16, 9)),
    )
}
|
|
|
|
|
|
def is_prompt_safe(prompt, terms=None):
    """Return True if *prompt* contains none of the restricted terms.

    Each term is matched case-insensitively as a whole word, optionally
    followed by a plural "s". So "drug" also blocks "drugs", but not
    "drugstore".

    Args:
        prompt: User-supplied prompt text.
        terms: Iterable of terms to block. Defaults to the module-level
            ``RESTRICTED_TERMS``. Empty strings are ignored.

    Returns:
        False if any term matches, otherwise True (including when there are
        no non-empty terms).
    """
    if terms is None:
        terms = RESTRICTED_TERMS
    # Lower-case the terms too, since the prompt is lower-cased before matching.
    words = [re.escape(term.lower()) for term in terms if term]
    if not words:
        # An empty alternation would match at every word boundary and
        # reject every prompt.
        return True
    # One alternation instead of one search per term; `re` caches the
    # compiled pattern across calls.
    pattern = r"\b(?:" + "|".join(words) + r")s?\b"
    return re.search(pattern, prompt.lower()) is None
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Handle GET / by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
|
|
|
|
|
|
# The shared `pipe` holds mutable scheduler state, and Flask's server handles
# requests on multiple threads, so pipeline calls are serialized.
_PIPE_LOCK = threading.Lock()


def _save_generated_image(image):
    """Save *image* under STATIC_IMAGES_PATH and return its path relative to
    ``static/``, using forward slashes (URL-style) on every OS."""
    timestamp = int(time.time() * 1000)
    # The random suffix keeps names unique when concurrent requests land in
    # the same millisecond; a timestamp alone could overwrite another image.
    filename = f"generated_{timestamp}_{uuid.uuid4().hex}.png"
    image_path = os.path.join(STATIC_IMAGES_PATH, filename)
    image.save(image_path)
    return os.path.relpath(image_path, "static").replace(os.sep, "/")


@app.route("/generate", methods=["POST"])
def generate():
    """Generate images for a prompt submitted as form data.

    Form fields:
        prompt: Required text prompt (surrounding whitespace is stripped).
        num_images: Integer from 1 to 5 (default 1).
        aspect_ratio: A key of ``ASPECT_RATIOS`` (default "1:1").
        model: Must be "stable_diffusion"; any other value is rejected as locked.

    Returns:
        200 with JSON ``{"images": [...]}``, each entry a path relative to ``static/``.
        400 with JSON ``{"error": ...}`` for invalid input or a restricted prompt.
        500 with JSON ``{"error": ...}`` if generation or saving fails.
    """
    prompt = request.form.get("prompt", "").strip()
    aspect_ratio = request.form.get("aspect_ratio", "1:1")
    model_name = request.form.get("model", "stable_diffusion")
    try:
        num_images = int(request.form.get("num_images", 1))
    except ValueError:
        # Non-numeric input is a client error, not a server failure.
        num_images = None

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    if model_name != "stable_diffusion":
        return jsonify({"error": "Selected model is locked"}), 400
    if num_images is None or not 1 <= num_images <= 5:
        return jsonify({"error": "Number of images must be between 1 and 5"}), 400
    if aspect_ratio not in ASPECT_RATIOS:
        return jsonify({"error": "Invalid aspect ratio"}), 400
    if not is_prompt_safe(prompt):
        return jsonify({
            "error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
        }), 400

    width, height = ASPECT_RATIOS[aspect_ratio]

    image_paths = []
    try:
        for i in range(num_images):
            # Diffusers pipelines take randomness through `generator`; a
            # `seed=` keyword is not applied. A CPU generator works for both
            # CPU and CUDA pipelines, giving per-image seeds 42, 43, ...
            generator = torch.Generator(device="cpu").manual_seed(42 + i)
            with _PIPE_LOCK:
                image = pipe(
                    prompt,
                    width=width,
                    height=height,
                    num_inference_steps=50,
                    guidance_scale=7.5,
                    generator=generator,
                ).images[0]
            image_paths.append(_save_generated_image(image))
    except Exception as e:
        # Top-level boundary: log the traceback server-side and report the
        # failure to the client as before.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({"images": image_paths})
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. Presumably this is the hosting
    # platform's expected port — NOTE(review): confirm. Debug stays off
    # because the Werkzeug debugger must never be exposed on a public
    # interface.
    run_options = {"host": "0.0.0.0", "port": 7860, "debug": False}
    app.run(**run_options)