"""Flask front end for a fine-tuned Stable Diffusion model.

Routes:
    GET  /          -> renders ``index.html``.
    POST /generate  -> validates form input, renders 1-5 images from a prompt,
                       saves them under ``static/images`` and returns their
                       static-relative URL paths as JSON.
"""

import os
import re
import time
import uuid

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, render_template, request
from PIL import Image  # noqa: F401  (kept from the original file; not used directly)

app = Flask(__name__)

# Define paths
MODEL_PATH = "Roshan1162003/fine_tuned_model"  # Replace with your HF model repository ID
STATIC_IMAGES_PATH = os.path.join("static", "images")
os.makedirs(STATIC_IMAGES_PATH, exist_ok=True)

# Restricted terms for prompt filtering
RESTRICTED_TERMS = [
    "crime", "abuse", "violence", "illegal", "explicit", "nsfw",
    "offensive", "hate", "nude", "porn", "gore", "drug",
]

# One pre-compiled whole-word alternation. It matches exactly when any single
# term matches as a whole word, so the per-call regex rebuilding is avoided.
_RESTRICTED_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(term) for term in RESTRICTED_TERMS) + r")\b"
)

# Load the fine-tuned model once at import time.
# fp16 is only used on CUDA; the CPU path stays in fp32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_PATH,
    torch_dtype=DTYPE,
    use_safetensors=True,
    use_auth_token=os.getenv("HF_TOKEN"),  # Use HF_TOKEN from environment
).to(DEVICE)
print("Model loaded successfully")

# Aspect ratio to (width, height) resolution mapping
ASPECT_RATIOS = {
    "1:1": (512, 512),
    "4:3": (512, 384),
    "16:9": (512, 288),
}

MIN_IMAGES = 1
MAX_IMAGES = 5
BASE_SEED = 42  # image i of a request is generated with seed BASE_SEED + i


def is_prompt_safe(prompt):
    """Return True if *prompt* contains none of RESTRICTED_TERMS as a whole word.

    The check is case-insensitive (the prompt is lower-cased; all terms are
    lower-case).
    """
    return _RESTRICTED_PATTERN.search(prompt.lower()) is None


def _generate_images(prompt, num_images, width, height):
    """Render *num_images* images for *prompt*, save them, and return URL paths.

    Each image i uses a generator seeded with ``BASE_SEED + i`` so results are
    reproducible. Returned paths are relative to the ``static`` folder and
    always use "/" (they are URLs, not filesystem paths).
    """
    timestamp = int(time.time() * 1000)
    url_paths = []
    for i in range(num_images):
        # diffusers takes randomness through `generator`; it has no `seed` kwarg.
        generator = torch.Generator(device=DEVICE).manual_seed(BASE_SEED + i)
        image = pipe(
            prompt,
            width=width,
            height=height,
            num_inference_steps=50,
            guidance_scale=7.5,
            generator=generator,
        ).images[0]

        # A random suffix prevents collisions between concurrent requests that
        # land on the same millisecond.
        filename = f"generated_{timestamp}_{i}_{uuid.uuid4().hex[:8]}.png"
        image.save(os.path.join(STATIC_IMAGES_PATH, filename))
        url_paths.append(f"images/{filename}")
    return url_paths


@app.route("/", methods=["GET"])
def index():
    """Serve the main page."""
    return render_template("index.html")


@app.route("/generate", methods=["POST"])
def generate():
    """Validate form input and generate images.

    Form fields: ``prompt`` (required), ``num_images`` (int, 1-5, default 1),
    ``aspect_ratio`` (key of ASPECT_RATIOS, default "1:1"),
    ``model`` (only "stable_diffusion" is accepted).

    Returns:
        200 ``{"images": [...]}`` on success.
        400 ``{"error": ...}`` for invalid input or a restricted prompt.
        500 ``{"error": ...}`` if generation itself fails (traceback is logged).
    """
    # Get form data
    prompt = request.form.get("prompt", "").strip()
    aspect_ratio = request.form.get("aspect_ratio", "1:1")
    model_name = request.form.get("model", "stable_diffusion")
    try:
        num_images = int(request.form.get("num_images", 1))
    except (TypeError, ValueError):
        # Non-numeric input is a client error, not a server failure.
        return jsonify({"error": "Number of images must be an integer"}), 400

    # Validate inputs
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    if model_name != "stable_diffusion":
        return jsonify({"error": "Selected model is locked"}), 400
    if not MIN_IMAGES <= num_images <= MAX_IMAGES:
        return jsonify({"error": "Number of images must be between 1 and 5"}), 400
    if aspect_ratio not in ASPECT_RATIOS:
        return jsonify({"error": "Invalid aspect ratio"}), 400

    # Check for restricted terms
    if not is_prompt_safe(prompt):
        return jsonify({
            "error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
        }), 400

    width, height = ASPECT_RATIOS[aspect_ratio]
    try:
        image_paths = _generate_images(prompt, num_images, width, height)
    except Exception as e:  # top-level boundary: report and log, never crash the worker
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"images": image_paths})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False)