Spaces:
Sleeping
Sleeping
| from pipeline.workflow import recognize_graph | |
| from pipeline.commons import here | |
| import os | |
| import jinja2 | |
| import pickle | |
| import pm4py | |
| from pm4py.visualization.petri_net import visualizer as pn_visualizer | |
| from collections import Counter, defaultdict | |
| from typing import List, Tuple | |
| from pipeline.models import Place, Transition, Arc | |
| from groq import Groq | |
| import os | |
| import base64 | |
| from config.path_config import ( | |
| OUTPUT_DIR, TEMPLATES_DIR, VISUALIZATIONS_DIR, PIPELINE_OUTPUT_DIR, | |
| PLACES_PKL_PATH, TRANSITIONS_PKL_PATH, ARCS_PKL_PATH, | |
| PLACES_FIXED_PKL_PATH, TRANSITIONS_FIXED_PKL_PATH, ARCS_FIXED_PKL_PATH, | |
| OUTPUT_PNML_PATH, OUTPUT_PETRIOBJ_PATH, OUTPUT_JSON_PATH, OUTPUT_PNG_PATH, OUTPUT_GV_PATH, | |
| WORKING_IMAGE_PATH, ensure_directories_exist, get_visualization_path, get_output_file_path | |
| ) | |
def process_elements(places: List[Place], transitions: List[Transition], arcs: List[Arc]) -> Tuple[List[Place], List[Transition], List[Arc]]:
    """Sanitize a Petri-net element set.

    Drops places that have no connected arcs, drops transitions with fewer
    than two connected arcs (together with those arcs), and finally flips a
    single arc where needed so that every surviving transition has at least
    one incoming and one outgoing arc.

    Note: Arc objects may be mutated in place (source/target swapped).

    Returns
    -------
    Tuple[List[Place], List[Transition], List[Arc]]
        The filtered places, transitions and arcs.
    """
    def touches(arc, node):
        # An arc is connected to a node when the node is either endpoint.
        return arc.source == node or arc.target == node

    # 1) Drop isolated places, then every arc attached to a dropped place.
    isolated = {p for p in places if not any(touches(a, p) for a in arcs)}
    kept_places = [p for p in places if p not in isolated]
    kept_arcs = [a for a in arcs if a.source not in isolated and a.target not in isolated]

    # 2) Drop under-connected transitions (fewer than two arcs) plus their arcs.
    dead_transitions = set()
    dead_arcs = set()
    for t in transitions:
        attached = [a for a in kept_arcs if touches(a, t)]
        if len(attached) < 2:
            dead_transitions.add(t)
            dead_arcs.update(attached)
    kept_transitions = [t for t in transitions if t not in dead_transitions]
    kept_arcs = [a for a in kept_arcs if a not in dead_arcs]

    # 3) Guarantee each surviving transition has >=1 incoming and >=1 outgoing
    #    arc by reversing exactly one arc when one direction is missing.
    for t in kept_transitions:
        attached = [a for a in kept_arcs if touches(a, t)]
        n_out = sum(1 for a in attached if a.source == t)
        n_in = sum(1 for a in attached if a.target == t)
        if n_out == 0 and n_in >= 1:
            # No outputs: flip the first incoming arc into an outgoing one.
            flip = next(a for a in attached if a.target == t)
            flip.source, flip.target = flip.target, flip.source
        elif n_in == 0 and n_out >= 1:
            # No inputs: flip the first outgoing arc into an incoming one.
            flip = next(a for a in attached if a.source == t)
            flip.source, flip.target = flip.target, flip.source

    return kept_places, kept_transitions, kept_arcs
def fix_petri_net():
    """Method that checks for all the errors in the petri net, logs the errors and applies fixes, if readily available.

    Reads the raw pickled places/transitions/arcs, then:
    - reports (but cannot automatically fix) duplicate ids across all element kinds,
    - drops arcs that connect two elements of the same type,
    - clamps arc weights to a minimum of 1,
    - merges parallel arcs (same source and target) by summing their weights,
    - removes hanging places and under-connected transitions via process_elements,
    and finally writes the fixed elements to the *_FIXED pickle paths.
    """
    with open(PLACES_PKL_PATH, "rb") as f:
        places = pickle.load(f)
    with open(TRANSITIONS_PKL_PATH, "rb") as f:
        transitions = pickle.load(f)
    with open(ARCS_PKL_PATH, "rb") as f:
        arcs = pickle.load(f)

    ### Report duplicate ids across places, transitions and arcs (no automatic fix available)
    all_ids = []
    all_ids.extend(place.id for place in places)
    all_ids.extend(transition.id for transition in transitions)
    all_ids.extend(arc.id for arc in arcs)
    id_duplicates = [element_id for element_id, count in Counter(all_ids).items() if count > 1]
    for duplicate_id in id_duplicates:
        duplicate_elements = []
        duplicate_elements.extend([place for place in places if place.id == duplicate_id])
        duplicate_elements.extend([transition for transition in transitions if transition.id == duplicate_id])
        duplicate_elements.extend([arc for arc in arcs if arc.id == duplicate_id])
        print(f"Duplicate ID {duplicate_id} found in elements: {duplicate_elements}")

    ### Drop arcs connecting two elements of the same type (place->place / transition->transition)
    # `is not` compares the type objects themselves, which is the intended check.
    arcs = [arc for arc in arcs if type(arc.source) is not type(arc.target)]

    ### Fix weights if any are less than 1
    for arc in arcs:
        if arc.weight < 1:
            print(f"Arc found with weight less than 1: {arc}")
            print("Applying fix to set the weight to 1")
            arc.weight = 1

    ### Merge parallel arcs (same source and same target) into one arc carrying the summed weight
    arc_groups = defaultdict(list)
    for arc in arcs:
        arc_groups[(arc.source.id, arc.target.id)].append(arc)
    redundant_arc_ids = set()  # id() of the merged-away arc objects
    for (source_id, target_id), group in arc_groups.items():
        if len(group) > 1:
            print(f"Found {len(group)} parallel arcs between same source and target: {source_id} -> {target_id}")
            # Keep the first arc as the merged one and fold all weights into it.
            group[0].weight = sum(arc.weight for arc in group)
            # Track the rest by object identity: list.remove() compares with ==
            # and could delete the wrong (merely equal) arc, and repeated
            # remove() calls are O(n^2) on large nets.
            redundant_arc_ids.update(id(arc) for arc in group[1:])
    arcs = [arc for arc in arcs if id(arc) not in redundant_arc_ids]

    ### There should be no hanging places, every place must have at least one arc
    ### Every transition must have at least one input and one output arc
    places, transitions, arcs = process_elements(places, transitions, arcs)
    print(f"len(places): {len(places)}")
    print(f"len(transitions): {len(transitions)}")
    print(f"len(arcs): {len(arcs)}")

    ### Save the fixed elements for the rendering stages
    ensure_directories_exist()
    with open(PLACES_FIXED_PKL_PATH, "wb") as f:
        pickle.dump(places, f)
    with open(TRANSITIONS_FIXED_PKL_PATH, "wb") as f:
        pickle.dump(transitions, f)
    with open(ARCS_FIXED_PKL_PATH, "wb") as f:
        pickle.dump(arcs, f)
def run_and_save_pipeline(config_path: str, image_path: str):
    """Run the graph-recognition pipeline on an image and persist its outputs.

    Saves every intermediate visualization image produced by the pipeline
    and pickles the recognized places/transitions/arcs for the later fix
    and rendering stages.
    """
    recognition = recognize_graph(image_path, config_path)
    places = recognition["places"]
    transitions = recognition["transitions"]
    arcs = recognition["arcs"]

    ensure_directories_exist()

    # Persist each intermediate visualization under its pipeline stage name.
    for viz_name, viz_image in recognition["visualizations"].items():
        viz_image.save(get_visualization_path(viz_name))

    # Pickle the recognized elements for the downstream stages.
    for pkl_path, payload in (
        (PLACES_PKL_PATH, places),
        (TRANSITIONS_PKL_PATH, transitions),
        (ARCS_PKL_PATH, arcs),
    ):
        with open(pkl_path, "wb") as f:
            pickle.dump(payload, f)

    print(f"Recognition complete. Found {len(places)} places, {len(transitions)} transitions, and {len(arcs)} arcs.")
def render_diagram_to(file_type: str):
    """Render the fixed Petri-net elements into the final model text.

    Parameters
    ----------
    file_type : str
        The template/output format to use; one of "pnml" or "petriobj".

    Returns
    -------
    str
        The rendered model text (also written to the output file).

    Raises
    ------
    ValueError
        If `file_type` is not a supported format.
    """
    # Validate up front, before any pickle loading or template-env setup.
    if file_type not in ("pnml", "petriobj"):
        raise ValueError(f"Invalid file type: {file_type}")

    with open(PLACES_FIXED_PKL_PATH, "rb") as f:
        places = pickle.load(f)
    with open(TRANSITIONS_FIXED_PKL_PATH, "rb") as f:
        transitions = pickle.load(f)
    with open(ARCS_FIXED_PKL_PATH, "rb") as f:
        arcs = pickle.load(f)

    context = {"places": places, "transitions": transitions, "arcs": arcs}
    if file_type == "petriobj":
        # The petriobj template addresses elements by their positional index.
        context["place_to_index"] = {place.id: index for index, place in enumerate(places)}
        context["transition_to_index"] = {transition.id: index for index, transition in enumerate(transitions)}

    # One shared render/write path for every supported format.
    template_loader = jinja2.FileSystemLoader(
        searchpath=TEMPLATES_DIR
    )
    template_env = jinja2.Environment(loader=template_loader)
    template = template_env.get_template(f"template.{file_type}.jinja")
    output_text = template.render(context)

    output_file_path = get_output_file_path(f"output.{file_type}")
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(output_text)
    return output_text
def render_to_graphviz():
    """Convert the generated PNML model into Graphviz (.gv) and PNG renderings."""
    petri_net, initial_marking, final_marking = pm4py.read_pnml(OUTPUT_PNML_PATH)
    graph = pn_visualizer.apply(petri_net, initial_marking, final_marking)
    pn_visualizer.save(graph, OUTPUT_GV_PATH)
    pm4py.save_vis_petri_net(petri_net, initial_marking, final_marking, OUTPUT_PNG_PATH)
def render_to_json():
    """Ask a Groq vision model to transcribe the working image into JSON.

    Sends the base64-encoded Petri-net image together with a JSON-schema
    prompt to the model and writes the model's JSON answer to
    OUTPUT_JSON_PATH.

    Raises
    ------
    ValueError
        If the GROQ_API_KEY environment variable is not set.
    """
    # Encode the working image so it can travel inline as a data URL.
    with open(WORKING_IMAGE_PATH, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Load API key from environment variable; fail fast when it is missing.
    groq_api_key = os.getenv('GROQ_API_KEY')
    if not groq_api_key:
        raise ValueError("GROQ_API_KEY environment variable is not set. Please check your .env file.")

    # The expected output shape, passed to the model as the system prompt.
    schema_prompt = "{\n  \"places\": [\n    {\"id\": \"string\", \"tokens\": \"integer\"}\n  ],\n  \"transitions\": [\n    {\"id\": \"string\", \"delay\": \"number_or_string\"}\n  ],\n  \"arcs\": [\n    {\"source\": \"string\", \"target\": \"string\", \"weight\": \"integer\"}\n  ]\n}"
    user_content = [
        {
            "type": "text",
            "text": "Take image of a Petri net as input and provide the textual representation of the graph in json format, according to this json template.",
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}",
            },
        },
    ]

    client = Groq(api_key=groq_api_key)
    completion = client.chat.completions.create(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        messages=[
            {"role": "system", "content": schema_prompt},
            {"role": "user", "content": user_content},
        ],
        temperature=1,
        max_completion_tokens=2048,
        top_p=1,
        stream=False,
        response_format={"type": "json_object"},
        stop=None,
    )

    with open(OUTPUT_JSON_PATH, "w", encoding="utf-8") as f:
        f.write(completion.choices[0].message.content)
if __name__ == "__main__":
    # Recognition step (run once to produce the raw pickles):
    # run_and_save_pipeline(config_path=here("../data/config.yaml"), image_path=here("../data/local/mid_petri_2.png"))
    fix_petri_net()
    # The rendering steps below are independent of each other and could run in parallel.
    for target_format in ("pnml", "petriobj"):
        render_diagram_to(target_format)
    render_to_graphviz()
    render_to_json()