Spaces:
Running
Running
| import base64 | |
| import io | |
| import random | |
| from textwrap import dedent | |
| import gradio as gr | |
| from openai import OpenAI | |
| from PIL import Image | |
| from pydantic import BaseModel | |
| from prompts import ( | |
| concept_generation_system_prompt, | |
| data_processing_generation_system_prompt, | |
| evaluator_system_prompt, | |
| fusion_generation_system_prompt, | |
| question_bias_generation_system_prompt, | |
| reasoning_generation_system_prompt, | |
| refine_system_prompt_concept, | |
| refine_system_prompt_data, | |
| refine_system_prompt_question_bias, | |
| refine_system_prompt_reason, | |
| refine_system_prompt_visual, | |
| refiner_system_prompt, | |
| review_system_prompt, | |
| visual_interpretation_generation_system_prompt, | |
| ) | |
class Distractor(BaseModel):
    """One generated wrong answer option together with its rationale."""

    text: str  # the distractor option text shown to the test taker
    reason: str  # why this option is plausibly confusable with the correct answer
class Distractors(BaseModel):
    """Structured-output wrapper: a list of generated distractors."""

    distractors: list[Distractor]
class Comment(BaseModel):
    """A reviewer remark about one candidate answer option."""

    option: str  # the option text being reviewed
    comment: str  # the reviewer's feedback on that option
class CommentFormat(BaseModel):
    """Structured-output wrapper: a list of reviewer comments."""

    comments: list[Comment]
class Judgement(BaseModel):
    """Evaluator verdict on a full multiple-choice question."""

    reasoning: str  # the evaluator's explanation of identified issues
    correctness: int  # numeric correctness score; exact scale defined by evaluator_system_prompt -- TODO confirm
    improvement: str  # suggested improvements fed to the refiner
class Question(BaseModel):
    """Structured output of the refiner: reasoning plus final distractor texts."""

    reasoning: str  # refiner's explanation of the changes made
    distractors: list[str]  # refined distractor option texts (plain strings)
def base64_to_image(base64_str):
    """Decode a base64-encoded image string into a ``PIL.Image.Image``."""
    return Image.open(io.BytesIO(base64.b64decode(base64_str)))
def get_reply(client, system_prompt, user_prompt, image_base64, output_format):
    """Send one multimodal chat request and return the parsed structured reply.

    Args:
        client: An ``openai.OpenAI`` client.
        system_prompt: System instructions; dedented before sending.
        user_prompt: User-turn text; dedented before sending.
        image_base64: Base64-encoded PNG, attached as a data URL.
        output_format: Pydantic model class the response is parsed into.

    Returns:
        dict: The parsed structured response as a plain dictionary.

    Raises:
        ValueError: If the model refused or produced no parsable output.
    """
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": dedent(system_prompt)},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": dedent(user_prompt)},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                ],
            },
        ],
        response_format=output_format,
        # temperature=0, # Set to 0 for deterministic responses
    )
    message = completion.choices[0].message
    # The model may refuse instead of emitting structured output; fail with a
    # clear error rather than an opaque AttributeError on ``None``.
    if message.parsed is None:
        raise ValueError(f"Model returned no structured output: {message.refusal!r}")
    parsed = message.parsed
    # ``.dict()`` is deprecated under pydantic v2; use ``model_dump`` when
    # available, falling back to ``.dict()`` for pydantic v1.
    return parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
def convert_to_multi_choice(client, question, answer, image_base64, reviewer):
    """Generate challenging distractors for (question, answer, image).

    For each of five error categories, a category-specific prompt generates
    candidate distractors; when ``reviewer`` is truthy each candidate set is
    reviewed and then refined against the reviewer comments. Finally all
    candidates are fused into a single deduplicated set by the fusion prompt.

    Args:
        client: An ``openai.OpenAI`` client.
        question: The open-ended question text.
        answer: The known correct answer.
        image_base64: Base64-encoded PNG of the associated image.
        reviewer: Whether to run the review-and-refine pass per category.

    Returns:
        list[dict]: Fused distractors, each with ``text`` and ``reason`` keys.
    """
    base_prompt = f"""
    Question: {question}
    Correct Answer: {answer}
    """
    # (review type label, generation system prompt, refine system prompt)
    # for each of the five distractor error categories.
    categories = [
        ("conceptual", concept_generation_system_prompt, refine_system_prompt_concept),
        ("reasoning", reasoning_generation_system_prompt, refine_system_prompt_reason),
        (
            "visual interpretation",
            visual_interpretation_generation_system_prompt,
            refine_system_prompt_visual,
        ),
        (
            "data processing",
            data_processing_generation_system_prompt,
            refine_system_prompt_data,
        ),
        (
            "question bias",
            question_bias_generation_system_prompt,
            refine_system_prompt_question_bias,
        ),
    ]

    all_distractors = []
    for review_type, generation_prompt, refine_prompt in categories:
        # 1) Generate candidate distractors for this error category.
        distractors = get_reply(
            client, generation_prompt, base_prompt, image_base64, Distractors
        )["distractors"]

        if reviewer:
            # 2) Review the candidates with a category-specific reviewer.
            review_user_prompt = f"""
            Question: {question}
            Correct Answer: {answer}
            Distractions and Reasonings: {distractors}
            """
            reviews = get_reply(
                client,
                review_system_prompt.format(type=review_type),
                review_user_prompt,
                image_base64,
                CommentFormat,
            )["comments"]

            # 3) Refine the candidates against the reviewer comments.
            refine_user_prompt = f"""
            Question: {question}
            Correct Answer: {answer}
            Distractions and Reviewer Comments: {reviews}
            """
            distractors = get_reply(
                client, refine_prompt, refine_user_prompt, image_base64, Distractors
            )["distractors"]

        all_distractors.extend(distractors)

    # 4) Fuse all category candidates into one final distractor set.
    fusion_user_prompt = f"""
    Question: {question}
    Correct Answer: {answer}
    All Distractors: {all_distractors}
    """
    return get_reply(
        client,
        fusion_generation_system_prompt,
        fusion_user_prompt,
        image_base64,
        Distractors,
    )["distractors"]
def judge_multichoice_correctness_with_image(
    client, question, choices, answer, image_base64
):
    """Ask the evaluator model to judge a multiple-choice question.

    Returns the evaluator's verdict as a dict with ``reasoning``,
    ``correctness`` and ``improvement`` keys.
    """
    prompt = f"""
    Question: {question}
    Choices: {choices}
    Correct Answer: {answer}
    """
    return get_reply(client, evaluator_system_prompt, prompt, image_base64, Judgement)
def improve_multichoice_correctness_with_image(
    client,
    question,
    choices,
    answer,
    issue,
    improvement,
    image_base64,
):
    """Ask the refiner model to rewrite the choices given evaluator feedback.

    Returns the refiner's output as a dict with ``reasoning`` and
    ``distractors`` keys.
    """
    prompt = f"""
    Question: {question}
    Choices: {choices}
    Correct Answer: {answer}
    Identified Issues: {issue}
    Suggested Improvements: {improvement}
    """
    return get_reply(client, refiner_system_prompt, prompt, image_base64, Question)
def process_one_question(api_key, image, question, answer, components):
    """Run the full AutoConverter pipeline on one (image, question, answer).

    Args:
        api_key: OpenAI API key used to construct the client.
        image: Image as a numpy array (as delivered by ``gr.Image``).
        question: The open-ended question text.
        answer: The known correct answer.
        components: Selected component names; recognised values are
            ``"Reviewer"`` and ``"Refiner"``.

    Returns:
        str: The formatted multiple-choice question with lettered options
        and the correct answer's letter.
    """
    reviewer = "Reviewer" in components
    refiner = "Refiner" in components

    # Encode the uploaded image as a base64 PNG for the vision API.
    pil_image = Image.fromarray(image)
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    buffer.seek(0)
    image_base64 = base64.b64encode(buffer.read()).decode("utf-8")

    random.seed(1234)  # deterministic shuffle order across runs
    client = OpenAI(api_key=api_key)

    distractors = convert_to_multi_choice(
        client, question, answer, image_base64, reviewer
    )
    choices = [item["text"] for item in distractors] + [answer]
    random.shuffle(choices)

    if refiner:
        judgement = judge_multichoice_correctness_with_image(
            client, question, choices, answer, image_base64
        )
        refined = improve_multichoice_correctness_with_image(
            client,
            question,
            choices,
            answer,
            judgement["reasoning"],
            judgement["improvement"],
            image_base64,
        )
        choices = refined["distractors"] + [answer]
        random.shuffle(choices)

    # Letter the options dynamically instead of assuming exactly four
    # choices -- the model is not guaranteed to return exactly three
    # distractors, and the old 'ABCD'/choices[0..3] form would IndexError.
    option_lines = "\n".join(
        f"{chr(ord('A') + i)}. {choice}" for i, choice in enumerate(choices)
    )
    answer_letter = chr(ord("A") + choices.index(answer))
    return f"Question: {question}\n\n{option_lines}\n\nAnswer: {answer_letter}"
def main_gradio():
    """Build and launch the Gradio demo UI for AutoConverter."""
    input_widgets = [
        gr.Textbox(label="OpenAI API Key"),
        gr.Image(label="Upload an Image"),
        gr.Textbox(label="Question"),
        gr.Textbox(label="Answer"),
        gr.CheckboxGroup(["Reviewer", "Refiner"], label="Components"),
    ]
    demo = gr.Interface(
        fn=process_one_question,
        inputs=input_widgets,
        outputs=gr.Textbox(label="Output"),
        title="AutoConverter: Automated Generation of Challenging Multiple-Choice Questions for Vision Language Model Evaluation",
    )
    demo.launch()
# Launch the web UI when this file is executed as a script.
if __name__ == "__main__":
    main_gradio()