Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import openai | |
| import requests | |
| import os | |
| from dotenv import load_dotenv | |
| import io | |
| import sys | |
| import json | |
| import PIL | |
| import time | |
| from stability_sdk import client | |
| import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation | |
| import markdown2 | |
# User-facing strings for the Gradio interface (the UI text is Japanese).
title = "najimino AI recipe generator"
inputs_label = "どんな料理か教えてくれれば,新しいレシピを考えます"
outputs_label = "najimino AIが返信をします"
visual_outputs_label = "料理のイメージ"

# Markdown shown under the title: advises keeping input/output to roughly
# 1000 characters and warns that a reply takes about 50 seconds.
description = """
- ※入出力の文字数は最大1000文字程度までを目安に入力してください。回答に50秒くらいかかります.
"""

# Footer text passed to the interface; currently just a newline.
article = "\n"
# Load variables from a .env file before any key is looked up.
load_dotenv()
openai.api_key = os.environ.get('OPENAI_API_KEY')

# Set the Stability host before the client is constructed.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'

# Engines previously tried here:
#   stable-diffusion-512-v2-1, stable-diffusion-xl-beta-v2-2-2,
#   stable-diffusion-xl-1024-v0-9
# Other engine names noted by the author:
#   stable-diffusion-v1, stable-diffusion-v1-5, stable-diffusion-512-v2-0,
#   stable-diffusion-768-v2-0, stable-diffusion-768-v2-1,
#   stable-inpainting-v1-0, stable-inpainting-512-v2-0
stability_api = client.StabilityInference(
    key=os.environ.get('STABILITY_KEY'),
    verbose=True,
    engine="stable-diffusion-xl-1024-v1-0",
)

# Chat model used for recipe generation. Previously tried:
# "gpt-4", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613".
MODEL = "gpt-3.5-turbo-1106"
def get_filetext(filename, cache={}):
    """Return the text of ``filename``, reading it from disk at most once.

    Args:
        filename: Path of the text file to read.
        cache: Mapping of filename -> file contents. The mutable default is
            intentional: it is shared between calls and acts as a
            process-wide memo. Pass your own dict for an isolated cache.

    Returns:
        The full contents of the file as a string.

    Raises:
        ValueError: If ``filename`` does not exist.
    """
    # Serve from the cache when the file has been read before.
    if filename in cache:
        return cache[filename]

    try:
        # Decode explicitly as UTF-8 rather than the platform default, so
        # non-ASCII text reads the same on every host.
        # NOTE(review): assumes the prompt/schema files are UTF-8 - confirm.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError as err:
        # Keep the exception type callers already expect for a missing file.
        raise ValueError(f"ファイル '{filename}' が見つかりませんでした") from err

    # Remember the contents for subsequent calls.
    cache[filename] = text
    return text
def get_functions_from_schema(filename):
    """Return the ``"functions"`` entry of the JSON schema file ``filename``.

    The file text is obtained through ``get_filetext`` and parsed as JSON.
    Returns ``None`` when the parsed object has no ``"functions"`` key.
    """
    parsed = json.loads(get_filetext(filename))
    return parsed.get("functions")
class StabilityAI:
    """Namespace for image generation through the module-level Stability client."""

    @classmethod
    def generate_image(cls, visualize_prompt):
        """Generate an image for ``visualize_prompt``.

        Declared as a classmethod because callers invoke it directly on the
        class (``StabilityAI.generate_image(prompt)``).

        Args:
            visualize_prompt: Text prompt sent to ``stability_api.generate``.

        Returns:
            The first image artifact in the response, opened as a PIL image,
            or ``None`` when the response contains no image artifact.
        """
        # ``import PIL`` alone does not load the ``Image`` submodule, so
        # import it explicitly instead of relying on another package having
        # imported it already.
        from PIL import Image

        print("visualize_prompt:"+visualize_prompt)
        answers = stability_api.generate(
            prompt=visualize_prompt,
        )
        for resp in answers:
            for artifact in resp.artifacts:
                # Log artifacts whose finish reason is FILTER.
                if artifact.finish_reason == generation.FILTER:
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return Image.open(io.BytesIO(artifact.binary))
        # No image artifact was produced.
        return None
class OpenAI:
    """Namespace for calls to the OpenAI Chat Completions API."""

    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        """Request a chat completion and start image generation from its visual prompt.

        The request consists of the contents of ``constraints.md`` and
        ``template.md`` as system messages, a fixed assistant
        acknowledgement, the user ``prompt`` and an assistant message
        holding ``start_with``.

        Args:
            prompt: User message to send.
            start_with: Text placed in the trailing assistant message.

        Returns:
            None.
            NOTE(review): the value returned by ``stability_api.generate``
            is neither consumed nor returned here - confirm whether this
            method is still needed.
        """
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")

        # Payload for the Chat Completions endpoint.
        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "system", "content": template},
                {"role": "assistant", "content": "Sure!"},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }

        # Measure how long text generation takes.
        start = time.time()
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}"
            },
            json=data
        )
        print("gpt generation time: "+str(time.time() - start))

        # Pull the generated text out of the response.
        result = response.json()
        print(result)
        content = result["choices"][0]["message"]["content"].strip()

        # The visual prompt is the text following this heading. Fall back to
        # a neutral prompt when the model omitted the heading instead of
        # failing with IndexError.
        split_content = content.split("### Prompt for Visual Expression\n\n")
        if len(split_content) > 1:
            visualize_prompt = split_content[1]
        else:
            visualize_prompt = "vacant dish"

        stability_api.generate(
            prompt=visualize_prompt,
        )

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions):
        """Call the chat model, forcing a ``format_recipe`` function call.

        Declared as a classmethod because callers invoke it directly on the
        class with keyword arguments.

        Args:
            prompt: The user prompt; only printed for logging.
            messages: Chat messages passed to the API.
            functions: Function definitions passed to the API.

        Returns:
            The ``message`` of the first choice in the response.
        """
        print("prompt:"+prompt)

        # Measure how long text generation takes.
        start = time.time()
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=messages,
            functions=functions,
            function_call={"name": "format_recipe"}
        )
        print("gpt generation time: "+str(time.time() - start))

        # Take the first choice's message from the response.
        message = response.choices[0].message
        print("chat completion message: " + json.dumps(message, indent=2))
        return message
class NajiminoAI:
    """Turn a user's message into a formatted recipe and an image of the dish."""

    def __init__(self, user_message):
        """Store the user's free-text request."""
        self.user_message = user_message

    def generate_recipe_prompt(self):
        """Return a prompt asking the model to fill ``template.md`` from the user message."""
        template = get_filetext(filename="template.md")
        return (
            f"\n{self.user_message}\n"
            "---\n"
            "上記を元に、下記テンプレートを埋めてください。\n"
            "---\n"
            f"{template}\n"
        )

    def format_recipe(self, lang, title, description, ingredients, instruction, comment_feelings_taste, explanation_to_blind_person, prompt_for_visual_expression):
        """Fill ``template.md`` with the given recipe fields and return the text.

        Each argument is substituted for the placeholder of the same name
        via ``str.format``. The result is also printed for debugging.
        """
        template = get_filetext(filename="template.md")
        debug_message = template.format(
            lang=lang,
            title=title,
            description=description,
            ingredients=ingredients,
            instruction=instruction,
            comment_feelings_taste=comment_feelings_taste,
            explanation_to_blind_person=explanation_to_blind_person,
            prompt_for_visual_expression=prompt_for_visual_expression
        )
        print("debug_message: "+debug_message)
        return debug_message

    @classmethod
    def generate(cls, user_message):
        """Create an instance for ``user_message`` and return its generated recipe.

        Declared as a classmethod because it is used directly on the class,
        both as the Gradio callback and from the command-line entry point.

        Returns:
            ``[image, html]`` as produced by ``generate_recipe``.
        """
        return cls(user_message).generate_recipe()

    def generate_recipe(self):
        """Generate the recipe and its image for ``self.user_message``.

        Sends the constraints and the user message to the chat model with
        the functions loaded from ``schema.json``. When the reply contains a
        function call, its JSON arguments are used to generate an image and
        to render the recipe template as HTML.

        Returns:
            ``[image, html]``; both are ``None`` when the reply has no
            function call.
        """
        user_message = self.user_message
        constraints = get_filetext(filename="constraints.md")
        messages = [
            {"role": "system", "content": constraints},
            {"role": "user", "content": user_message},
        ]
        functions = get_functions_from_schema('schema.json')
        message = OpenAI.chat_completion_with_function(prompt=user_message, messages=messages, functions=functions)

        image = None
        html = None
        if message.get("function_call"):
            args = json.loads(message["function_call"]["arguments"])

            # Treat a missing English visual prompt as empty so the style
            # suffix below can still be appended (None + str would raise).
            prompt_for_visual_expression_in_en = args.get("prompt_for_visual_expression_in_en") or ""
            prompt_for_visual_expression = (
                prompt_for_visual_expression_in_en
                + " delicious looking extremely detailed photo f1.2 (50mm|85mm) award winner depth of field bokeh perfect lighting "
            )
            print("prompt_for_visual_expression: "+prompt_for_visual_expression)

            # Measure how long image generation takes.
            start = time.time()
            image = StabilityAI.generate_image(prompt_for_visual_expression)
            print("image generation time: "+str(time.time() - start))

            function_response = self.format_recipe(
                lang=args.get("lang"),
                title=args.get("title"),
                description=args.get("description"),
                ingredients=args.get("ingredients"),
                instruction=args.get("instruction"),
                comment_feelings_taste=args.get("comment_feelings_taste"),
                explanation_to_blind_person=args.get("explanation_to_blind_person"),
                prompt_for_visual_expression=prompt_for_visual_expression
            )

            # Render the markdown recipe inside a scrollable container.
            html = (
                "<div style='max-width:100%; overflow:auto'>"
                + "<p>"
                + markdown2.markdown(function_response)
                + "</div>"
            )
        return [image, html]
def main():
    """Build the Gradio interface for the recipe generator and launch it.

    The interface takes one text box as input and shows the generated image
    together with the recipe HTML returned by ``NajiminoAI.generate``.
    """
    # The former nested ``click_example`` callback was never registered and
    # referenced undefined names, so it has been dropped; Gradio handles
    # clicks on the examples itself.
    iface = gr.Interface(
        fn=NajiminoAI.generate,
        examples=[
            ["ラー麺 スイカ かき氷 八ツ橋"],
            ["お好み焼き 鯖"],
            ["茹でたアスパラガスに合う季節のソース"],
        ],
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            gr.Image(label="Visual Expression"),
            "html",
        ],
        title=title,
        description=description,
        article=article,
    )
    iface.launch()
if __name__ == '__main__':
    # An optional first command-line argument selects a one-off debug
    # action; without it (or with any other value) the web UI is launched.
    function = ''
    if len(sys.argv) > 1:
        function = sys.argv[1]

    if function == 'generate':
        NajiminoAI.generate("グルテンフリーの香ばしいサバのお好み焼き")
    elif function == 'generate_image':
        # ``import PIL`` alone does not guarantee the plugin submodule is
        # loaded, so import it explicitly for the type check below.
        from PIL import PngImagePlugin

        image = StabilityAI.generate_image("Imagine a delicious gluten-free okonomiyaki with mackerel. The okonomiyaki is crispy on the outside and chewy on the inside. It is topped with savory sauce and creamy mayonnaise, creating a mouthwatering visual. The dish is garnished with finely chopped green onions and red pickled ginger, adding a pop of color. The mackerel fillets are beautifully grilled and placed on top of the okonomiyaki, adding a touch of elegance. The dish is served on a traditional Japanese plate, completing the visual presentation.")
        # Format via an f-string: concatenating a str with the image object
        # would raise TypeError.
        print(f"image: {image}")
        # Save only when a PNG image actually came back.
        if isinstance(image, PngImagePlugin.PngImageFile):
            image.save("image.png")
    else:
        main()