Roshan1162003's picture
Upload folder using huggingface_hub
4608b26 verified
from flask import Flask, render_template, request, jsonify
import os
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import re
import time
# Flask application object; the @app.route decorators below register the
# request handlers on it.
app = Flask(__name__)
# --- Paths -----------------------------------------------------------------
# Hugging Face model repository ID of the fine-tuned weights
# (replace with your own HF model repository ID).
MODEL_PATH = "Roshan1162003/fine_tuned_model"

# Directory that receives the generated PNG files; created at import time
# if it does not exist yet.
STATIC_IMAGES_PATH = os.path.join("static", "images")
os.makedirs(STATIC_IMAGES_PATH, exist_ok=True)

# --- Prompt filtering ------------------------------------------------------
# A prompt containing any of these terms is rejected by is_prompt_safe().
RESTRICTED_TERMS = [
    "crime",
    "abuse",
    "violence",
    "illegal",
    "explicit",
    "nsfw",
    "offensive",
    "hate",
    "nude",
    "porn",
    "gore",
    "drug",
]
# Load the fine-tuned pipeline once at import time. With CUDA available the
# weights are loaded as float16 and moved to the GPU; otherwise they are
# loaded as float32 and left where from_pretrained() put them.
_use_cuda = torch.cuda.is_available()
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
    use_safetensors=True,
    use_auth_token=os.getenv("HF_TOKEN"),  # access token read from the environment
)
if _use_cuda:
    pipe = pipe.to("cuda")
print("Model loaded successfully")
# Supported aspect ratios, each mapped to the (width, height) in pixels that
# is requested from the pipeline.
ASPECT_RATIOS = {
    "1:1": (512, 512),
    "4:3": (512, 384),
    "16:9": (512, 288),
}
def is_prompt_safe(prompt, restricted_terms=None):
    """Return True when *prompt* contains none of the restricted terms.

    Matching is case-insensitive and uses regex word boundaries, so a term
    only matches as a whole word ("crime" does not match "crimea").

    Args:
        prompt: Prompt text to check.
        restricted_terms: Optional iterable of terms to reject. When omitted,
            the module-level ``RESTRICTED_TERMS`` list is used.

    Returns:
        False if any restricted term occurs in the prompt, True otherwise.
    """
    if restricted_terms is None:
        restricted_terms = RESTRICTED_TERMS
    prompt_lower = prompt.lower()
    return not any(
        re.search(r"\b" + re.escape(term.lower()) + r"\b", prompt_lower)
        for term in restricted_terms
        # An empty term would match at every word boundary; ignore it.
        if term
    )
@app.route("/", methods=["GET"])
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/generate", methods=["POST"])
def generate():
    """Generate images for the submitted prompt and return their paths as JSON.

    Reads ``prompt``, ``num_images``, ``aspect_ratio`` and ``model`` from the
    POSTed form, validates them, runs the pipeline once per requested image
    and saves every result as a PNG under ``STATIC_IMAGES_PATH``.

    Returns:
        ``{"images": [...]}`` with forward-slash paths relative to the
        ``static`` folder on success; ``{"error": "..."}`` with status 400
        for invalid input, or status 500 for an unexpected failure.
    """
    try:
        # Get form data
        prompt = request.form.get("prompt", "").strip()
        raw_num_images = request.form.get("num_images", 1)
        aspect_ratio = request.form.get("aspect_ratio", "1:1")
        model_name = request.form.get("model", "stable_diffusion")

        # Validate inputs
        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400
        if model_name != "stable_diffusion":
            return jsonify({"error": "Selected model is locked"}), 400
        try:
            num_images = int(raw_num_images)
        except (TypeError, ValueError):
            # A non-numeric count is a client error, not a server failure.
            return jsonify({"error": "Number of images must be between 1 and 5"}), 400
        if num_images < 1 or num_images > 5:
            return jsonify({"error": "Number of images must be between 1 and 5"}), 400
        if aspect_ratio not in ASPECT_RATIOS:
            return jsonify({"error": "Invalid aspect ratio"}), 400

        # Check for restricted terms
        if not is_prompt_safe(prompt):
            return jsonify({
                "error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
            }), 400

        # Get resolution for aspect ratio
        width, height = ASPECT_RATIOS[aspect_ratio]

        # Generate images
        image_paths = []
        for i in range(num_images):
            # The pipeline call has no ``seed`` argument; seeding is done by
            # passing a seeded torch.Generator. A CPU generator is used so
            # the same seed is usable whether the pipeline is on CPU or GPU.
            generator = torch.Generator("cpu").manual_seed(42 + i)
            image = pipe(
                prompt,
                width=width,
                height=height,
                num_inference_steps=50,
                guidance_scale=7.5,
                generator=generator,
            ).images[0]

            # Save image
            timestamp = int(time.time() * 1000)
            image_path = os.path.join(STATIC_IMAGES_PATH, f"generated_{timestamp}_{i}.png")
            image.save(image_path)

            # Report the path relative to the static folder using forward
            # slashes, independent of the OS path separator.
            relative_path = os.path.relpath(image_path, "static").replace(os.sep, "/")
            image_paths.append(relative_path)

        return jsonify({"images": image_paths})
    except Exception as e:
        # Boundary handler: always answer with JSON, but keep the traceback
        # in the server log.
        # NOTE(review): the exception text is still echoed to the client, as
        # before; consider a generic message if that leaks internal details.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
# Script entry point: start Flask's built-in server on all interfaces,
# port 7860, with debug mode off.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 7860, "debug": False}
    app.run(**server_options)