Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import ast | |
| import schema | |
| import csv | |
| import json | |
| from pathlib import Path | |
| import random | |
| from typing import TYPE_CHECKING | |
| from uw_programmatic.base_machine import UWBaseMachine | |
| if TYPE_CHECKING: | |
| from griptape.tools import BaseTool | |
| class UWMachine(UWBaseMachine): | |
| """State machine with GOAP""" | |
| def tools(self) -> dict[str, BaseTool]: | |
| return {} | |
| def start_machine(self) -> None: | |
| """Starts the machine.""" | |
| # Clear input history. | |
| # Clear csv file | |
| self.retrieve_vector_stores() | |
| self.send("enter_first_state") | |
| def on_enter_gather_parameters(self) -> None: | |
| # Reinitialzes the state machine | |
| self.current_question_count = 0 | |
| self.give_up_count = 0 | |
| self.question_list = [] | |
| self.rejected_questions = [] | |
| # The first state: Listens for Gradio and then gives us the parameters to search for. | |
| # Reinitializes the Give Up counter. | |
| def on_event_gather_parameters(self, event_: dict) -> None: | |
| event_source = event_["type"] | |
| event_value = event_["value"] | |
| match event_source: | |
| case "user_input": | |
| parameters = event_value | |
| self.page_range = parameters["page_range"] | |
| self.question_number = parameters["question_number"] | |
| self.taxonomy = parameters["taxonomy"] | |
| self.errored = False | |
| self.send("next_state") | |
| case "griptape_event": | |
| if event_value["structure_id"] == "create_question_workflow": | |
| pass | |
| case _: | |
| err_msg = f"Unexpected Transition Event ID: {event_value}." | |
| raise ValueError(err_msg) | |
| # Checks if there have not been any new questions generated 3 tries in a row | |
| # If # of questions is the same as the # of questions required - sends to end. | |
| def on_enter_evaluate_q_count(self) -> None: | |
| if len(self.question_list) <= self.current_question_count: | |
| self.give_up_count += 1 | |
| else: | |
| self.current_question_count = len(self.question_list) | |
| self.give_up_count = 0 | |
| if self.give_up_count >= 3: | |
| self.send("finish_state") # go to output questions | |
| return | |
| if len(self.question_list) >= self.question_number: | |
| self.send("finish_state") # go to output questions | |
| else: | |
| self.send("next_state") # go to need more questions | |
| # Necessary for state machine to not throw errors | |
| def on_event_evaluate_q_count(self, event_: dict) -> None: | |
| pass | |
| def on_enter_need_more_q(self) -> None: | |
| # Create the entire workflow to create another question. | |
| self.get_questions_workflow().run() | |
| # Returns the output of the workflow - a ListArtifact of TextArtifacts of questions. | |
| # Question, Answer, Wrong Answers, Taxonomy, Page Number | |
| def on_event_need_more_q(self, event_: dict) -> None: | |
| event_source = event_["type"] | |
| event_value = event_["value"] | |
| match event_source: | |
| case "griptape_event": | |
| event_type = event_value["type"] | |
| match event_type: | |
| case "FinishStructureRunEvent": | |
| structure_id = event_value["structure_id"] | |
| match structure_id: | |
| case "create_question_workflow": | |
| # TODO: Can you use task.output_schema on a workflow? | |
| values = event_value["output_task_output"]["value"] | |
| questions = [ | |
| ast.literal_eval(question["value"]) | |
| for question in values | |
| ] | |
| self.most_recent_questions = ( | |
| questions # This is a ListArtifact | |
| ) | |
| self.send("next_state") | |
| case _: | |
| print(f"Error:{event_} ") | |
| case _: | |
| print(f"Unexpected: {event_}") | |
| # Merges the existing and new questions and sends to similarity auditor to get rid of similar questions. | |
| def on_enter_assess_generated_q(self) -> None: | |
| merged_list = [*self.question_list, *self.most_recent_questions] | |
| prompt = f"{merged_list}" | |
| similarity_auditor = self.get_structure("similarity_auditor") | |
| similarity_auditor.task.output_schema = schema.Schema( | |
| { | |
| "list": schema.Schema( | |
| [ | |
| { | |
| "Question": str, | |
| "Answer": str, | |
| "Wrong Answers": schema.Schema([str]), | |
| "Page": str, | |
| "Taxonomy": str, | |
| } | |
| ] | |
| ) | |
| } | |
| ) | |
| similarity_auditor.run(prompt) | |
| # Sets the returned question list (with similar questions wiped) equal to self.question_list | |
| def on_event_assess_generated_q(self, event_: dict) -> None: | |
| event_source = event_["type"] | |
| event_value = event_["value"] | |
| match event_source: | |
| case "griptape_event": | |
| event_type = event_value["type"] | |
| match event_type: | |
| case "FinishStructureRunEvent": | |
| structure_id = event_value["structure_id"] | |
| match structure_id: | |
| case "similarity_auditor": | |
| new_question_list = event_value["output_task_output"][ | |
| "value" | |
| ]["list"] | |
| self.question_list = new_question_list | |
| self.send("next_state") # go to Evaluate Q Count | |
| # Writes and saves a csv in the correct format to outputs/professor_guide.csv | |
| def on_enter_output_q(self) -> None: | |
| file_path = Path.cwd().joinpath("outputs/professor_guide.csv") | |
| file_path.parent.mkdir(parents=True, exist_ok=True) | |
| with file_path.open("w+", newline="") as file: | |
| writer = csv.writer(file) | |
| for question in self.question_list: | |
| new_row = ["MC", "", 1] | |
| new_row.append(question["Question"]) | |
| wrong_answers = list(question["Wrong Answers"]) | |
| column = random.randint(1, len(wrong_answers) + 1) | |
| new_row.append(column) | |
| for i in range(1, len(wrong_answers) + 2): | |
| if i == column: | |
| new_row.append(question["Answer"]) | |
| else: | |
| new_row.append(wrong_answers.pop()) | |
| new_row.append("'"+question["Page"]) | |
| new_row.append(question["Taxonomy"]) | |
| writer.writerow(new_row) | |
| if self.give_up_count == 3: | |
| writer.writerow( | |
| [ | |
| "Too many rejected questions.", | |
| ] | |
| ) | |
| rejected_path = Path.cwd().joinpath("outputs/rejected_list.csv") | |
| with rejected_path.open("w+", newline="") as rejected_file: | |
| writer = csv.writer(rejected_file) | |
| for question in self.rejected_questions: | |
| writer.writerow(question.values()) | |
| self.send("next_state") # back to gather_parameters | |
| # Necessary to prevent errors being thrown from state machine | |
| def on_event_output_q(self, event_: dict) -> None: | |
| pass | |