# Hugging Face Spaces status: Build error
| from openai import OpenAI | |
| import gradio as gr | |
| import requests | |
| from PIL import Image | |
| import numpy as np | |
| import ipadic | |
| import MeCab | |
| import difflib | |
| import io | |
| import os | |
# Shared OpenAI client for image generation and embeddings; the API key is
# taken from the OPENAI_API_KEY environment variable (None if it is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def generate_image(text):
    """Return the path of a 512x512 PNG rendered from *text*, creating it if needed.

    Images are cached on disk at ``./{text}.png``. When that file already
    exists, it is returned without calling the API. Otherwise the following
    happens:

    1. DALL-E 3 is asked for one 1024x1024 image.
    2. The image is downloaded.
    3. It is downscaled to 512x512 and saved at that path.

    Args:
        text: The prompt to render. It is also used verbatim as the cache
            file name, so it must be non-empty and free of path separators.

    Returns:
        The relative path ``./{text}.png``.

    Raises:
        ValueError: If *text* is empty or contains a path separator. Such
            text would otherwise let user input write files outside the
            working directory, or fail only after the API call was paid for.
        requests.HTTPError: If downloading the generated image fails.
    """
    if not text:
        raise ValueError("prompt must not be empty")
    separators = [sep for sep in (os.sep, os.altsep) if sep]
    if any(sep in text for sep in separators):
        raise ValueError("prompt must not contain path separators")

    image_path = f"./{text}.png"
    if not os.path.exists(image_path):
        response = client.images.generate(
            model="dall-e-3",
            prompt=text,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        # Bound the download so a stalled request cannot hang the handler, and
        # surface HTTP errors instead of feeding an error page to PIL.
        download = requests.get(image_url, timeout=60)
        download.raise_for_status()
        with Image.open(io.BytesIO(download.content)) as img:
            img.resize((512, 512)).save(image_path)
    return image_path
def cos_sim(v1, v2):
    """Return the cosine similarity of vectors *v1* and *v2*.

    This is their dot product divided by the product of their Euclidean norms.
    """
    numerator = np.dot(v1, v2)
    denominator = np.linalg.norm(v1) * np.linalg.norm(v2)
    return numerator / denominator
def _similarity_to_score(similarity):
    """Map a cosine similarity to an integer score in the range 0..99.

    ``similarity * 100`` is rounded to the nearest integer in a single step.
    The former ``int(round(s, 2) * 100)`` truncated values such as 0.57 to 56,
    because ``0.57 * 100 == 56.99999999999999`` in binary floating point.
    100 is reserved for an exact text match, so the result is capped at 99.
    Negative similarities are clamped to 0.
    """
    score = int(round(float(similarity) * 100))
    return min(max(score, 0), 99)


def calculate_similarity_score(ori_text, text):
    """Score how close the guess *text* is to the answer *ori_text* (0-100).

    An exact string match scores 100. Otherwise both texts are embedded with
    ``text-embedding-3-small``, and the cosine similarity of the two
    embeddings is converted to an integer score between 0 and 99.
    """
    if ori_text == text:
        return 100
    response = client.embeddings.create(
        input=[ori_text, text], model="text-embedding-3-small"
    )
    similarity = cos_sim(response.data[0].embedding, response.data[1].embedding)
    return _similarity_to_score(similarity)
def tokenize_text(text):
    """Split *text* into surface-form tokens using MeCab with the IPA dictionary.

    MeCab runs in ChaSen output mode. The first whitespace-separated field of
    each output line is kept as the token. The final output line (MeCab's EOS
    marker) is dropped.
    """
    tagger = MeCab.Tagger(f"-Ochasen {ipadic.MECAB_ARGS}")
    output_lines = tagger.parse(text).splitlines()
    tokens = []
    for line in output_lines[:-1]:
        tokens.append(line.split()[0])
    return tokens
def create_match_words(ori_text, text):
    """Return the tokens of the guess *text* that also appear in the answer *ori_text*.

    Tokens keep the order and repetitions in which they occur in *text*.
    """
    answer_tokens = set(tokenize_text(ori_text))
    return [token for token in tokenize_text(text) if token in answer_tokens]
def create_hint_text(ori_text, text):
    """Build a hint that masks the answer characters the guess is missing.

    The guess *text* is diffed character by character against the answer
    *ori_text* using :func:`difflib.ndiff`:

    - Characters present in both are copied through.
    - Characters present only in the answer are replaced by ``"X"``.
    - Characters present only in the guess are dropped.

    Example: ``create_hint_text("a cat", "a dog") == "a XXX"``.

    Returns:
        A string with one character per character of *ori_text*.
    """
    pieces = []
    for line in difflib.ndiff(list(text), list(ori_text)):
        tag, char = line[:2], line[2:]
        if tag == "+ ":
            pieces.append("X")
        elif tag == "  ":
            # Slice rather than strip(): a shared whitespace character must
            # survive, otherwise spaces silently disappear from the hint.
            pieces.append(char)
        # "- " (guess-only characters) and "? " (intraline markers) are skipped.
    return "".join(pieces)
def update_question(option):
    """Return the cached image path for the selected question, or None.

    Args:
        option: The question key chosen in the radio (e.g. ``"Q1"``). The
            answer prompt is read from the environment variable of that name.

    Returns:
        ``"./{answer}.png"``, the path under which the answer's image is
        cached. Returns ``None`` (leaving the image output empty) when no
        question is selected or its environment variable is unset. The old
        code instead raised TypeError or pointed at a nonexistent
        ``./None.png``.
    """
    if not option:
        return None
    answer = os.getenv(option)
    if answer is None:
        return None
    return f"./{answer}.png"
def main(text, option):
    """Handle a guess: render it, score it against the answer, and build a hint.

    Args:
        text: The player's guessed prompt.
        option: The selected question key. The answer is read from the
            environment variable of that name.

    Returns:
        An ``(image_path, score_text, hint_text)`` tuple for the three output
        components. The hint depends on the score:

        - Below 80: the list of matching words.
        - 80-99: the answer with missing characters masked as "X".
        - 100: empty.

        If no question is selected, or its answer is not configured, nothing
        is generated and the hint asks the player to pick a question.
    """
    ori_text = os.getenv(option) if option else None
    if ori_text is None:
        # Previously os.getenv(None) raised TypeError when the player pressed
        # Submit before choosing a question.
        return None, "", "問題を選んでください!"

    image_path = generate_image(text)
    score = calculate_similarity_score(ori_text, text)
    if score < 80:
        match_words = create_match_words(ori_text, text)
        hint_text = "一致している単語リスト: " + " ".join(match_words)
    elif score < 100:
        hint_text = "一致していない箇所: " + create_hint_text(ori_text, text)
    else:
        hint_text = ""
    return image_path, f"{score}点", hint_text
def auth(user_name, password):
    """Return True iff the credentials match the USER_NAME / PASSWORD env vars.

    Both fields are compared with :func:`hmac.compare_digest`, so response
    time does not reveal how much of a guess was correct. The comparison is
    done on UTF-8 bytes because ``compare_digest`` rejects non-ASCII ``str``.
    If either environment variable is unset, every login is rejected.
    """
    import hmac

    expected_user = os.getenv("USER_NAME")
    expected_password = os.getenv("PASSWORD")
    if expected_user is None or expected_password is None:
        return False
    # Evaluate both comparisons so a wrong user name is not detectably faster.
    user_ok = hmac.compare_digest(
        (user_name or "").encode("utf-8"), expected_user.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        (password or "").encode("utf-8"), expected_password.encode("utf-8")
    )
    return user_ok and password_ok
# Pre-generate (or reuse cached) images for every question at startup, so a
# question's image is ready as soon as it is selected in the UI.
questions = ["Q1", "Q2", "Q3"]
for q in questions:
    answer = os.getenv(q)
    if answer is None:
        # Fail fast with a clear message instead of sending a None prompt to
        # the image API and crashing with an opaque request error.
        raise RuntimeError(f"environment variable {q!r} (answer for {q}) is not set")
    image_path = generate_image(answer)
# Two-column layout. The left column holds the instructions, the question
# picker, the question image and the guess box. The right column shows the
# image generated from the guess together with its score and hint.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "# プロンプトを当てるゲーム \n これは表示されている画像のプロンプトを当てるゲームです。プロンプトを入力するとそれに対応した画像とスコアとヒントが表示されます。スコア100点を目指して頑張ってください! \n\nヒントは80点未満の場合は当たっている単語(順番は合っているとは限らない)、80点以上の場合は足りない文字を「X」で示した文字列を表示しています。",
            )
            option = gr.Radio(["Q1", "Q2", "Q3"], label="問題を選んでください!")
            output_title_image = gr.Image(type="filepath", label="お題")
            # Selecting a question swaps in its pre-generated image.
            option.change(
                update_question,
                inputs=[option],
                outputs=[output_title_image],
            )
            input_text = gr.Textbox(
                lines=1,
                label="画像にマッチするテキストを入力して!",
            )
            submit_button = gr.Button("Submit!")
        with gr.Column():
            output_image = gr.Image(type="filepath", label="生成画像")
            output_score = gr.Textbox(lines=1, label="スコア")
            output_hint_text = gr.Textbox(lines=1, label="ヒント")
    # Submitting a guess renders it, then reports its score and hint.
    submit_button.click(
        main,
        inputs=[input_text, option],
        outputs=[output_image, output_score, output_hint_text],
    )

demo.launch(auth=auth)