# meme_diffusion / app.py
# (Hugging Face Spaces page header — author: Omnibus, commit 351b316 "Update app.py")
from huggingface_hub import InferenceClient
from PIL import Image,ImageFont,ImageDraw
import gradio as gr
import requests
import random
import uuid
import io
from utils import models, MEME_GENERATOR,GENERATE_PROMPT
# Remote text-generation backend used to produce the meme prompt/text.
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# Preload every diffusion model listed in utils.models via gr.load so that
# run() can dispatch to one by dropdown index.
# (Original looped with enumerate() but never used the index.)
loaded_model = [gr.load(f'models/{model}', cache_examples=False) for model in models]
def textover(im, text):
    """Draw *text* in black at the top-left corner of *im* (in place) and return it.

    Uses the bundled unifont TTF at a fixed 20px size.
    """
    origin = (0, 0)
    text_color = (0, 0, 0)  # black
    point_size = 20
    drawer = ImageDraw.Draw(im)
    fnt = ImageFont.truetype("./fonts/unifont-15.0.01.ttf", int(point_size))
    drawer.text(origin, text, font=fnt, fill=text_color)
    return im
def get_concat_h_cut(in1, in2):
    """Open two image files and return them pasted side by side.

    The result is as wide as both images combined and as tall as the
    shorter of the two (the taller image is cut at the bottom).
    """
    left = Image.open(in1)
    right = Image.open(in2)
    canvas_size = (left.width + right.width, min(left.height, right.height))
    canvas = Image.new('RGB', canvas_size)
    canvas.paste(left, (0, 0))
    canvas.paste(right, (left.width, 0))
    return canvas
def get_concat_v_cut(im1, theme='light'):
    """Return a new image: a 200px-tall colored banner with *im1* pasted below it.

    The banner area (where textover() later draws the meme text) is dark
    gray for theme='dark', white otherwise.

    BUG FIX: previously `color` was only assigned for 'dark' and 'light',
    so any other theme value raised NameError; now light is the default.
    """
    if theme == 'dark':
        color = (31, 41, 55)
    else:
        color = (255, 255, 255)
    dst = Image.new('RGB', (im1.width, im1.height + 200), color=color)
    dst.paste(im1, (0, 200))
    return dst
def format_prompt(message, history):
    """Render chat *history* plus the new *message* in Mixtral instruct format.

    Each (user, assistant) turn becomes "[INST] user [/INST] assistant</s> ",
    prefixed once with "<s>"; the new message is appended as a final
    unanswered "[INST] ... [/INST]".
    """
    pieces = ["<s>"]
    for user_turn, assistant_turn in history:
        pieces.append(f"[INST] {user_turn} [/INST]")
        pieces.append(f" {assistant_turn}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
# Available system-prompt "agents"; only the meme generator exists today.
# NOTE(review): this list is not referenced anywhere in this file — presumably
# kept for future agent selection; verify before removing.
agents =[
"MEME_GENERATOR",
]
def generate(prompt, history):
    """Stream a meme specification from the LLM.

    Yields ([(prompt, partial_output)], parsed) pairs for the Chatbot and
    JSON components.  Once the full response contains both "PROMPT:" and
    "MEME_TEXT:" markers, `parsed` becomes
    {'PROMPT': <diffusion prompt>, 'MEME_TEXT': <caption>}.
    """
    print(f'HISTORY:: {history}')
    history = []  # chat history is intentionally discarded: each meme is a fresh request
    output1 = {}
    seed = random.randint(1, 1111111111111111)
    system_prompt = MEME_GENERATOR
    generate_kwargs = dict(
        temperature=0.7,
        max_new_tokens=256,
        top_p=0.95,
        repetition_penalty=1,
        do_sample=True,
        seed=seed,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield [(prompt, output)], output1
    # BUG FIX: original condition was `if "PROMPT:" and "MEME_TEXT:" in output:`
    # — the literal "PROMPT:" is always truthy, so only the second substring was
    # checked, and output lacking "PROMPT:" crashed the split below.
    if "PROMPT:" in output and "MEME_TEXT:" in output:
        print("YES")
        prompt_t = output.split("PROMPT:", 1)[1].split("MEME_TEXT:", 1)[0].strip()
        print(prompt_t)
        meme_t = output.split("MEME_TEXT:", 1)[1].strip()
        print(meme_t)
        output1 = {'PROMPT': prompt_t, 'MEME_TEXT': meme_t}
    yield [(prompt, output)], output1
def run(inp, model_drop):
    """Render the meme image for a parsed spec.

    inp: dict with 'PROMPT' (diffusion prompt) and 'MEME_TEXT' (caption),
         as produced by generate().
    model_drop: index into loaded_model (Dropdown uses type="index").

    Returns the composed PIL image, or None if the image download fails.
    """
    prompt = inp['PROMPT']
    text = inp['MEME_TEXT']
    model = loaded_model[int(model_drop)]
    out_img = model(prompt)
    print(out_img)
    # gr.load models return a file path inside the Space; fetch it back over
    # HTTP from the Space's file endpoint.
    url = f'https://omnibus-meme-diffusion.hf.space/file={out_img}'
    print(url)
    r = requests.get(url, stream=True)
    if r.status_code == 200:
        out = Image.open(io.BytesIO(r.content))
        out = get_concat_v_cut(out)   # add the white caption banner
        out = textover(out, text)     # draw the meme text onto the banner
        return out
    # Download failed — make the fallthrough explicit (original returned
    # None implicitly).  Unused `uid = uuid.uuid4()` was removed.
    return None
def run_gpt(in_prompt, history):
    """Run a single non-streaming-ish generation pass and return the full text.

    NOTE(review): this function is not wired to any UI event in this file.

    BUG FIX: the original unconditionally referenced `max_prompt` and
    `condense`, which are defined nowhere in this file and are not imported
    from utils — calling run_gpt always raised NameError.  The condense step
    now only runs when both names actually exist at module level.
    """
    if "max_prompt" in globals() and "condense" in globals() and len(in_prompt) > max_prompt:
        in_prompt = condense(in_prompt)
    print(f'history :: {history}')
    prompt = format_prompt(in_prompt, history)
    seed = random.randint(1, 1111111111111111)
    print(seed)
    generate_kwargs = dict(
        temperature=1.0,
        max_new_tokens=1048,
        top_p=0.99,
        repetition_penalty=1.0,
        do_sample=True,
        seed=seed,
    )
    content = GENERATE_PROMPT + prompt
    print(content)
    stream = client.text_generation(content, **generate_kwargs, stream=True, details=True, return_full_text=False)
    resp = ""
    for response in stream:
        resp += response.token.text
    return resp
# Gradio UI: left column holds the chat + controls, right column shows the
# generated image and the parsed JSON spec.  (Indentation reconstructed —
# the source had its whitespace stripped; nesting follows the widget order.)
with gr.Blocks() as app:
    gr.HTML("""<center><h1>Meme Diffusion</h1></center>""")
    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            model_drop = gr.Dropdown(label="Diffusion Models", type="index", choices=[m for m in models], value=models[0])
            with gr.Group():
                submit_b = gr.Button("Meme")
                submit_im = gr.Button("Image")
                with gr.Row():
                    stop_b = gr.Button("Stop")
                    clear = gr.ClearButton([msg, chatbot])
        with gr.Column(scale=2):
            im_out = gr.Image(label="Image")
            json_out = gr.JSON()
    # "Meme" streams the LLM spec into the chatbot + JSON panel;
    # "Image" renders the JSON spec; "Stop" cancels either job.
    sub_b = submit_b.click(generate, [msg, chatbot], [chatbot, json_out])
    sub_im = submit_im.click(run, [json_out, model_drop], [im_out])
    stop_b.click(None, None, None, cancels=[sub_b, sub_im])
app.launch()