ShahbazAhmad-Lab commited on
Commit
a1d9a53
·
verified ·
1 Parent(s): baf6a86

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ import requests
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ import gradio as gr
8
+
9
+ # ----------------------------
10
+ # Logging Configuration
11
+ # ----------------------------
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # ----------------------------
16
+ # Constants
17
+ # ----------------------------
18
+ HF_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
19
+ DEFAULT_STYLES = [
20
+ "Realistic", "Cinematic", "Cyberpunk",
21
+ "Studio Lighting", "Highly Detailed", "4K"
22
+ ]
23
+
24
+ # ----------------------------
25
+ # Utility Functions
26
+ # ----------------------------
27
+
28
+ def get_hf_token():
29
+ """Load Hugging Face token from environment variable."""
30
+ token = os.getenv("HF_TOKEN")
31
+ if not token:
32
+ raise EnvironmentError("HF_TOKEN not found in environment variables.")
33
+ return token
34
+
35
+
36
+ def style_prompt(user_input: str, style: str = None) -> str:
37
+ """Enhance prompt with selected style."""
38
+ if not user_input.strip():
39
+ raise ValueError("Prompt cannot be empty.")
40
+
41
+ if style and style != "None":
42
+ enhanced = f"{user_input}, {style}, ultra quality, sharp focus"
43
+ else:
44
+ enhanced = f"{user_input}, high quality"
45
+
46
+ return enhanced
47
+
48
+
49
+ def query_hf_api(prompt, retries=3, timeout=60, seed=None):
50
+ """Send request to Hugging Face Inference API with retry logic."""
51
+ headers = {
52
+ "Authorization": f"Bearer {get_hf_token()}",
53
+ "Content-Type": "application/json"
54
+ }
55
+
56
+ payload = {
57
+ "inputs": prompt,
58
+ "options": {"wait_for_model": True}
59
+ }
60
+
61
+ if seed is not None:
62
+ payload["parameters"] = {"seed": seed}
63
+
64
+ for attempt in range(retries):
65
+ try:
66
+ response = requests.post(
67
+ HF_API_URL,
68
+ headers=headers,
69
+ json=payload,
70
+ timeout=timeout
71
+ )
72
+
73
+ if response.status_code == 200:
74
+ return response.content
75
+
76
+ elif response.status_code == 503:
77
+ logger.warning("Model loading, retrying...")
78
+ time.sleep(5)
79
+
80
+ elif response.status_code == 429:
81
+ logger.warning("Rate limit hit, retrying...")
82
+ time.sleep(10)
83
+
84
+ else:
85
+ logger.error(f"API Error: {response.text}")
86
+ raise RuntimeError(f"API Error: {response.text}")
87
+
88
+ except requests.exceptions.Timeout:
89
+ logger.warning("Timeout occurred, retrying...")
90
+ time.sleep(5)
91
+
92
+ raise RuntimeError("Failed after multiple retries.")
93
+
94
+
95
+ def generate_image(prompt, style, seed):
96
+ """Main function for Gradio."""
97
+ try:
98
+ styled_prompt = style_prompt(prompt, style)
99
+ image_bytes = query_hf_api(styled_prompt, seed=seed)
100
+
101
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
102
+
103
+ return image
104
+
105
+ except Exception as e:
106
+ logger.error(str(e))
107
+ return f"Error: {str(e)}"
108
+
109
+
110
+ # ----------------------------
111
+ # Gradio UI
112
+ # ----------------------------
113
+
114
+ with gr.Blocks() as app:
115
+ gr.Markdown("# 🎨 AI Image Generator (FLUX.1-schnell)")
116
+ gr.Markdown("Generate high-quality images from text prompts using Hugging Face.")
117
+
118
+ with gr.Row():
119
+ prompt_input = gr.Textbox(
120
+ label="Enter your prompt",
121
+ placeholder="e.g., A futuristic city at sunset"
122
+ )
123
+
124
+ with gr.Row():
125
+ style_dropdown = gr.Dropdown(
126
+ ["None"] + DEFAULT_STYLES,
127
+ label="Select Style",
128
+ value="None"
129
+ )
130
+ seed_input = gr.Number(
131
+ label="Seed (optional)",
132
+ value=None,
133
+ precision=0
134
+ )
135
+
136
+ generate_btn = gr.Button("Generate Image")
137
+
138
+ output_image = gr.Image(label="Generated Image")
139
+ download_btn = gr.File(label="Download Image")
140
+
141
+ examples = gr.Examples(
142
+ examples=[
143
+ ["A dragon flying over mountains", "Cinematic", 42],
144
+ ["Cyberpunk city at night", "Cyberpunk", 123],
145
+ ["Portrait of a warrior", "Realistic", 7],
146
+ ],
147
+ inputs=[prompt_input, style_dropdown, seed_input],
148
+ )
149
+
150
+ def generate_and_download(prompt, style, seed):
151
+ image = generate_image(prompt, style, seed)
152
+ if isinstance(image, str):
153
+ return None, None
154
+
155
+ file_path = "output.png"
156
+ image.save(file_path)
157
+ return image, file_path
158
+
159
+ generate_btn.click(
160
+ fn=generate_and_download,
161
+ inputs=[prompt_input, style_dropdown, seed_input],
162
+ outputs=[output_image, download_btn]
163
+ )
164
+
165
+ app.launch()