K00B404 commited on
Commit
9b5f4e1
·
verified ·
1 Parent(s): 094e3a9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import io
4
+ import random
5
+ import os
6
+ import time
7
+ from PIL import Image
8
+ from deep_translator import GoogleTranslator
9
+ import json
10
+ from gradio_client import Client
11
+
12
+ # Project by Nymbo
13
+
14
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large"
15
+ API_TOKEN = os.getenv("HF_READ_TOKEN")
16
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
17
+ timeout = 100
18
+
19
+ # Initialize the prompt enhancer client
20
+ prompt_enhancer = Client("K00B404/mistral-nemo-prompt-enhancer")
21
+
22
+ def enhance_prompt(prompt):
23
+ """Enhance the given prompt using the Mistral Nemo prompt enhancer."""
24
+ try:
25
+ system_message = "You are an expert at writing detailed, high-quality image generation prompts. Enhance the given prompt by adding relevant artistic details, style elements, and quality descriptors. Keep the original intent but make it more elaborate and specific."
26
+ enhanced = prompt_enhancer.predict(
27
+ message=prompt,
28
+ system_message=system_message,
29
+ max_tokens=512,
30
+ temperature=0.7,
31
+ top_p=0.95,
32
+ api_name="/chat"
33
+ )
34
+ print(f'\033[1mOriginal prompt:\033[0m {prompt}')
35
+ print(f'\033[1mEnhanced prompt:\033[0m {enhanced}')
36
+ return enhanced
37
+ except Exception as e:
38
+ print(f"Error enhancing prompt: {e}")
39
+ return prompt # Fall back to original prompt if enhancement fails
40
+
41
+ # Function to query the API and return the generated image
42
+ def query(prompt, is_negative=False, steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024):
43
+ if prompt == "" or prompt is None:
44
+ return None
45
+
46
+ key = random.randint(0, 999)
47
+
48
+ API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN")])
49
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
50
+
51
+ # Translate the prompt from Russian to English if necessary
52
+ prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
53
+ print(f'\033[1mGeneration {key} translation:\033[0m {prompt}')
54
+
55
+ # Enhance the prompt using the Mistral Nemo model
56
+ prompt = enhance_prompt(prompt)
57
+
58
+ # Add some extra flair to the prompt
59
+ prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
60
+ print(f'\033[1mGeneration {key} final prompt:\033[0m {prompt}')
61
+
62
+ # Prepare the payload for the API call, including width and height
63
+ payload = {
64
+ "inputs": prompt,
65
+ "is_negative": is_negative,
66
+ "steps": steps,
67
+ "cfg_scale": cfg_scale,
68
+ "seed": seed if seed != -1 else random.randint(1, 1000000000),
69
+ "strength": strength,
70
+ "parameters": {
71
+ "width": width,
72
+ "height": height
73
+ }
74
+ }
75
+
76
+ # Send the request to the API and handle the response
77
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
78
+ if response.status_code != 200:
79
+ print(f"Error: Failed to get image. Response status: {response.status_code}")
80
+ print(f"Response content: {response.text}")
81
+ if response.status_code == 503:
82
+ raise gr.Error(f"{response.status_code} : The model is being loaded")
83
+ raise gr.Error(f"{response.status_code}")
84
+
85
+ try:
86
+ # Convert the response content into an image
87
+ image_bytes = response.content
88
+ image = Image.open(io.BytesIO(image_bytes))
89
+ print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')
90
+ return image
91
+ except Exception as e:
92
+ print(f"Error when trying to open the image: {e}")
93
+ return None
94
+
95
+ # CSS to style the app
96
+ css = """
97
+ #app-container {
98
+ max-width: 800px;
99
+ margin-left: auto;
100
+ margin-right: auto;
101
+ }
102
+ """
103
+
104
+ # Build the Gradio UI with Blocks
105
+ with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
106
+ # Add a title to the app
107
+ gr.HTML("<center><h1>Stable Diffusion 3.5 Large with Prompt Enhancement</h1></center>")
108
+
109
+ # Container for all the UI elements
110
+ with gr.Column(elem_id="app-container"):
111
+ # Add a text input for the main prompt
112
+ with gr.Row():
113
+ with gr.Column(elem_id="prompt-container"):
114
+ with gr.Row():
115
+ text_prompt = gr.Textbox(
116
+ label="Prompt",
117
+ placeholder="Enter a prompt here - it will be automatically enhanced for better results",
118
+ lines=2,
119
+ elem_id="prompt-text-input"
120
+ )
121
+
122
+ # Accordion for advanced settings
123
+ with gr.Row():
124
+ with gr.Accordion("Advanced Settings", open=False):
125
+ negative_prompt = gr.Textbox(
126
+ label="Negative Prompt",
127
+ placeholder="What should not be in the image",
128
+ value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
129
+ lines=3,
130
+ elem_id="negative-prompt-text-input"
131
+ )
132
+ with gr.Row():
133
+ width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
134
+ height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
135
+ steps = gr.Slider(label="Sampling steps", value=35, minimum=1, maximum=100, step=1)
136
+ cfg = gr.Slider(label="CFG Scale", value=7, minimum=1, maximum=20, step=1)
137
+ strength = gr.Slider(label="Strength", value=0.7, minimum=0, maximum=1, step=0.001)
138
+ seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1)
139
+ method = gr.Radio(
140
+ label="Sampling method",
141
+ value="DPM++ 2M Karras",
142
+ choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"]
143
+ )
144
+
145
+ # Add a button to trigger the image generation
146
+ with gr.Row():
147
+ text_button = gr.Button("Generate Enhanced Image", variant='primary', elem_id="gen-button")
148
+
149
+ # Image output area to display the generated image
150
+ with gr.Row():
151
+ image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
152
+
153
+ # Bind the button to the query function with all inputs
154
+ text_button.click(
155
+ query,
156
+ inputs=[text_prompt, negative_prompt, steps, cfg, method, seed, strength, width, height],
157
+ outputs=image_output
158
+ )
159
+
160
+ # Launch the Gradio app
161
+ app.launch(show_api=False, share=False)