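import os

import gradio as gr
from huggingface_hub import InferenceClient, login

# Authenticate only when a token is set: login() without a token falls back to
# an interactive prompt, which cannot be answered inside a hosted app.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)

# Instruction-tuned chat model that rewrites the user's idea as an image prompt.
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"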


def image_generator(text_input, style):
    """Turn a short idea into a refined prompt with an LLM, then render it as an image."""
    system_input = "You are an expert prompt engineer with artistic flair."
    user_input = (
        f"Write a concise prompt for a {style} image containing {text_input}. "
        "Only return the prompt."
    )

    messages = [
        {"role": "system", "content": system_input},
        {"role": "user", "content": user_input},
    ]

    # Step 1: ask the chat model (repo_id) to write the image prompt.
    chat_client = InferenceClient(repo_id)
    chat_completion = chat_client.chat_completion(
        messages=messages,
        max_tokens=500,
    )
    prompt = chat_completion.choices[0].message.content

    # Step 2: render the refined prompt with SDXL; the fixed seed means repeated
    # runs with the same prompt should return the same image.
    image_client = InferenceClient()
    image = image_client.text_to_image(
        prompt=prompt,
        model="stabilityai/stable-diffusion-xl-base-1.0",
        guidance_scale=8,
        seed=42,
    )
    return prompt, image



with gr.Blocks() as demo:
    with gr.Row():
        # Left column: user inputs and controls.
        with gr.Column():
            input_text = gr.Textbox(label="Prompt")
            # Default to "fun" so the chat model never receives style=None.
            style = gr.Radio(["fun", "interesting"], value="fun", label="Style")
            with gr.Row():
                reset = gr.ClearButton([input_text])
                submit = gr.Button("Submit")
        # Right column: the refined prompt and the generated image.
        with gr.Column():
            prompt = gr.Textbox(interactive=False, label="Refined prompt")
            output_image = gr.Image(interactive=False, label="Result")

    # Run the two-step pipeline (prompt refinement, then image generation) on Submit.
    submit.click(
        fn=image_generator,
        inputs=[input_text, style],
        outputs=[prompt, output_image],
    )

    gr.Examples(
        examples=[
            ["a llama and a cookbook", "fun"],
            ["a squirrel", "interesting"],
        ],
        inputs=[input_text, style],
    )


if __name__ == "__main__":
    demo.launch()