File size: 3,660 Bytes
4608b26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from flask import Flask, render_template, request, jsonify
import os
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import re
import time

app = Flask(__name__)

# Define paths
# Hugging Face Hub repository ID handed to StableDiffusionPipeline.from_pretrained.
MODEL_PATH = "Roshan1162003/fine_tuned_model"  # Replace with your HF model repository ID

# Directory that generated PNGs are written into. The path is relative to the
# current working directory; it is created eagerly (and silently reused if it
# already exists) so the first request can save images without extra checks.
STATIC_IMAGES_PATH = os.path.join("static", "images")
os.makedirs(name=STATIC_IMAGES_PATH, exist_ok=True)

# Whole-word terms that make is_prompt_safe() reject a prompt. They are kept
# lowercase because the prompt is lowercased before it is matched against them.
RESTRICTED_TERMS = (
    "crime abuse violence illegal explicit nsfw "
    "offensive hate nude porn gore drug"
).split()

# Load the fine-tuned model once at import time; every request reuses `pipe`.
# CUDA gets half precision (float16); the CPU path keeps full float32.
_device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16 if _device == "cuda" else torch.float32,
    use_safetensors=True,
    # `token` is the documented diffusers name for the Hub credential; the old
    # `use_auth_token` kwarg is deprecated and may be ignored. Read from the
    # HF_TOKEN environment variable (None means anonymous access).
    token=os.getenv("HF_TOKEN"),
)
if _device == "cuda":
    pipe = pipe.to(_device)
print("Model loaded successfully")

# Supported aspect-ratio labels mapped to the (width, height) in pixels that
# generate() passes to the pipeline. Width is fixed at 512 and every dimension
# is a multiple of 8 (NOTE(review): presumably required by the model — confirm).
ASPECT_RATIOS = dict(
    [
        ("1:1", (512, 512)),
        ("4:3", (512, 384)),
        ("16:9", (512, 288)),
    ]
)

def is_prompt_safe(prompt, terms=None):
    """Return True if *prompt* contains none of the restricted terms.

    Matching is case-insensitive (prompt and terms are both lowercased) and
    whole-word: a term only matches when it sits on word boundaries, so the
    term "hate" matches "I hate it" but not "chateau". Inflected forms such as
    "drugs" are therefore not matched by the term "drug".

    Args:
        prompt: The user-supplied text to check.
        terms: Optional iterable of terms to check against. Defaults to the
            module-level ``RESTRICTED_TERMS``.

    Returns:
        False as soon as any term is found in the prompt, otherwise True.
    """
    if terms is None:
        terms = RESTRICTED_TERMS
    prompt_lower = prompt.lower()
    # re.escape keeps terms containing regex metacharacters literal.
    return not any(
        re.search(r"\b" + re.escape(term.lower()) + r"\b", prompt_lower)
        for term in terms
    )

@app.route("/", methods=["GET"])
def index():
    """Render the ``index.html`` template for GET requests to the root URL."""
    template_name = "index.html"
    return render_template(template_name)

@app.route("/generate", methods=["POST"])
def generate():
    """Generate images for a text prompt submitted as form data.

    Form fields:
        prompt: Required, non-blank text; rejected if ``is_prompt_safe`` fails.
        num_images: Integer in 1..5 (default 1).
        aspect_ratio: A key of ``ASPECT_RATIOS`` (default ``"1:1"``).
        model: Must be ``"stable_diffusion"`` (the default); any other value
            is reported as locked.

    Returns:
        On success, a JSON response ``{"images": [...]}`` where each entry is
        the saved file's path relative to ``static`` using ``/`` separators
        (e.g. ``"images/generated_<ms>_<i>.png"``). Invalid input yields
        ``{"error": ...}`` with HTTP 400; any other failure yields
        ``{"error": ...}`` with HTTP 500.
    """
    try:
        # Get form data
        prompt = request.form.get("prompt", "").strip()
        aspect_ratio = request.form.get("aspect_ratio", "1:1")
        model_name = request.form.get("model", "stable_diffusion")
        try:
            num_images = int(request.form.get("num_images", 1))
        except ValueError:
            # Non-numeric input is a client error: flag it here and reject it
            # with a 400 below instead of letting it surface as a 500.
            num_images = None

        # Validate inputs (original order preserved: first failing check wins)
        if not prompt:
            return jsonify({"error": "Prompt is required"}), 400
        if model_name != "stable_diffusion":
            return jsonify({"error": "Selected model is locked"}), 400
        if num_images is None or not 1 <= num_images <= 5:
            return jsonify({"error": "Number of images must be between 1 and 5"}), 400
        if aspect_ratio not in ASPECT_RATIOS:
            return jsonify({"error": "Invalid aspect ratio"}), 400

        # Check for restricted terms
        if not is_prompt_safe(prompt):
            return jsonify({
                "error": "You are violating the regulation policy terms and conditions due to restricted terms in the prompt."
            }), 400

        width, height = ASPECT_RATIOS[aspect_ratio]

        # NOTE(review): `pipe` is shared by every request and diffusers does not
        # document pipelines as thread-safe; concurrent requests may need to be
        # serialized — confirm for this deployment.
        image_paths = []
        for i in range(num_images):
            # The pipeline has no `seed` argument (it was silently ignored or
            # rejected, depending on the diffusers version); seeding goes through
            # `generator`. Per the diffusers reproducibility guide, a CPU
            # generator works for CPU and CUDA pipelines alike.
            generator = torch.Generator(device="cpu").manual_seed(42 + i)
            image = pipe(
                prompt,
                width=width,
                height=height,
                num_inference_steps=50,
                guidance_scale=7.5,
                generator=generator,
            ).images[0]

            # Save image
            timestamp = int(time.time() * 1000)
            image_path = os.path.join(STATIC_IMAGES_PATH, f"generated_{timestamp}_{i}.png")
            image.save(image_path)
            # Return a URL-style path relative to "static". relpath + "/" stays
            # correct on Windows, where os.path.join produces backslashes and a
            # plain replace("static/", "") would leave the prefix in place.
            image_paths.append(os.path.relpath(image_path, "static").replace(os.sep, "/"))

        return jsonify({"images": image_paths})

    except Exception as e:
        # Top-level request boundary: keep the full traceback in the server log.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Serve on every interface (0.0.0.0), port 7860, with the debugger disabled.
    # NOTE(review): 7860 looks like the Hugging Face Spaces default port — confirm.
    app.run(debug=False, host="0.0.0.0", port=7860)