Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| from PIL import Image | |
| import time | |
| import os | |
| import base64 | |
| from io import BytesIO | |
# Read the Hugging Face API token from the environment. When it is missing,
# HF_TOKEN_ERROR holds a user-facing message (otherwise it is None); the
# request handlers check it and raise it as a gr.Error before calling the API.
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_TOKEN_ERROR = (
    None
    if HF_TOKEN
    else "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
)

# Single inference client shared by the prompt-improvement and image calls.
client = InferenceClient(token=HF_TOKEN)

# Text-generation model used to rewrite user prompts before image generation.
PROMPT_IMPROVER_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"
def improve_prompt(original_prompt):
    """Rewrite a user's image prompt into a more detailed one using an LLM.

    Sends ``original_prompt`` to ``PROMPT_IMPROVER_MODEL`` through the shared
    ``client``. This step is best-effort: if the API call fails, or the model
    returns nothing, the original prompt is returned unchanged so that image
    generation can still proceed.

    Args:
        original_prompt: The prompt text entered by the user.

    Returns:
        The improved prompt, or ``original_prompt`` when the input is empty,
        the call fails, or the model produces an empty reply.

    Raises:
        gr.Error: If the Hugging Face token is not configured.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    # Nothing to improve; don't ask the LLM to invent a prompt from scratch.
    if not original_prompt or not original_prompt.strip():
        return original_prompt

    system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
    # Qwen2.5 instruct models are trained on the ChatML template
    # (<|im_start|>role ... <|im_end|>). The previous Zephyr-style
    # <|system|> ... </s> tags did not match the model, and its "</s>" stop
    # sequence could never fire.
    end_marker = "<|im_end|>"
    prompt_for_llm = (
        f"<|im_start|>system\n{system_prompt}{end_marker}\n"
        f"<|im_start|>user\nImprove this prompt: {original_prompt}{end_marker}\n"
        "<|im_start|>assistant\n"
    )
    try:
        improved_prompt = client.text_generation(
            prompt=prompt_for_llm,
            model=PROMPT_IMPROVER_MODEL,
            max_new_tokens=1280,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            stop_sequences=[end_marker],
        ).strip()
        # Some backends echo the matched stop sequence at the end of the text.
        if improved_prompt.endswith(end_marker):
            improved_prompt = improved_prompt[: -len(end_marker)].rstrip()
    except Exception as e:
        # Deliberate best-effort: log and fall back to the user's own prompt
        # instead of failing the whole image request.
        print(f"Error improving prompt: {e}")
        return original_prompt
    # An empty completion is useless as an image prompt; keep the original.
    return improved_prompt or original_prompt
def generate_image(prompt, progress=gr.Progress(), *, model="black-forest-labs/FLUX.1-schnell"):
    """Generate an image for ``prompt``, after rewriting it with the prompt improver.

    Args:
        prompt: User-entered prompt text.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.
        model: Hugging Face text-to-image model id passed to
            ``client.text_to_image``.

    Returns:
        The generated ``PIL.Image.Image``.

    Raises:
        gr.Error: If the token is missing, the prompt is empty, the inference
            call fails, or the call returns something other than a PIL image.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)

    progress(0.2, desc="Sending request...")
    try:
        image = client.text_to_image(improved_prompt, model=model)
    except Exception as e:
        # The exception text is searched for "rate limit" so that this
        # failure gets a friendlier, more specific message.
        if "rate limit" in str(e).lower():
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        raise gr.Error(error_message) from e

    # Same user-visible message the original produced for a non-image result.
    if not isinstance(image, Image.Image):
        raise gr.Error(f"An error occurred: Expected a PIL Image, but got: {type(image)}")

    progress(1.0, desc="Done!")
    return image
def pil_to_base64(img):
    """Return ``img`` encoded as a PNG ``data:`` URI.

    ``img`` must provide ``save(fp, format=...)``, as ``PIL.Image.Image``
    does. The image is serialized to PNG in memory, then base64-encoded.
    """
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded
# Custom stylesheet for the elem_classes below; currently empty, so those
# classes carry no styling yet.
css = """
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Xylaria Iris v3
""",
        elem_classes="title",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
            generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section"):
                image_output = gr.Image(label="Generated Image", interactive=False)

    def on_generate_click(prompt, progress=gr.Progress()):
        """Event handler: generate an image for ``prompt`` and return it.

        Declaring a ``gr.Progress`` default here is what lets Gradio attach a
        live progress tracker to this event. The tracker is forwarded to
        ``generate_image`` so its stage updates reach the UI.
        """
        # The previous version mutated output_group.elem_classes here. Those
        # changes are not sent to the already-rendered page, they affect a
        # component shared by all users, and the "animate" class has no CSS.
        return generate_image(prompt, progress)

    # Trigger generation from the button and from pressing Enter in the textbox.
    generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)

    gr.Examples(
        [["A dog"], ["A house on a hill"], ["A spaceship"]],
        inputs=prompt_input,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()