Spaces:
Build error
Build error
| import gradio as gr | |
| from datasets import load_dataset | |
| from PIL import Image | |
| import json | |
| import torch | |
| from torchvision import transforms | |
| from transformers import DonutProcessor, VisionEncoderDecoderModel | |
| # import subprocess | |
| # # Install mlflow and dagshub without dependencies | |
| # subprocess.run(['pip', 'install', '--no-deps', 'mlflow']) | |
| # subprocess.run(['pip', 'install', '--no-deps', 'dagshub']) | |
| import dagshub | |
| import mlflow | |
| import time | |
| import os | |
# DagsHub credentials are read from the `dags_hub_token` environment variable
# (a Space secret). Earlier notebook versions read the same secret through
# kaggle_secrets.UserSecretsClient (Kaggle) or google.colab.userdata (Colab).
token = os.getenv('dags_hub_token')
if not token:
    # Fail fast with an actionable message instead of handing None to dagshub.
    raise RuntimeError(
        "Environment variable 'dags_hub_token' is not set; add it as a secret "
        "so the app can authenticate with DagsHub."
    )
dagshub.auth.add_app_token(token)
# mlflow=True asks dagshub to configure MLflow for this repository, which the
# mlflow.artifacts / mlflow.transformers calls below rely on.
dagshub.init(repo_owner='zaheramasha',
             repo_name='Finetuning_paligemma_Zaka_capstone',
             mlflow=True)
# MLflow run that produced the served Donut model (PyvizAndMarkMap dataset):
#   c41cfd149a8c44f3a92d8e0f1253af35 - 27 epochs, train loss 0.168
#   89bafd5e525a4d3e9d004e13c9574198 - continuation of the run above,
#                                      27 + 51 = 78 epochs, train loss 0.0353
run_id = "89bafd5e525a4d3e9d004e13c9574198"
artifact_path = "Donut_model/model"
# List what the run stored under the model artifact path (handy when a load fails).
print(mlflow.artifacts.list_artifacts(run_id=run_id, artifact_path=artifact_path))

# Artifact-store URI of the transformers model logged by the 78-epoch run.
# The earlier 27-epoch model lives at:
#   mlflow-artifacts:/0a5d0550f55c4169b80cd6439556be8b/c41cfd149a8c44f3a92d8e0f1253af35/artifacts/Donut_model
model_uri = "mlflow-artifacts:/17c375f6eab34c63b2a2e7792803132e/89bafd5e525a4d3e9d004e13c9574198/artifacts/Donut_model"
# Load on CPU; pass device='cuda' instead to load straight onto a GPU.
loaded_model_bundle = mlflow.transformers.load_model(model_uri=model_uri, device='cpu')
model = loaded_model_bundle.model
# Rebuild the Donut processor from the components that were logged with the model.
processor = DonutProcessor(
    tokenizer=loaded_model_bundle.tokenizer,
    feature_extractor=loaded_model_bundle.feature_extractor,
    image_processor=loaded_model_bundle.image_processor,
)
print(model.config.encoder.image_size)
print(model.config.decoder.max_length)
| import json | |
| import random | |
| from typing import Any, List, Tuple, Dict | |
| import torch | |
| from torch.utils.data import Dataset | |
| from datasets import load_dataset, DatasetDict, concatenate_datasets | |
| from PIL import Image, ImageFilter | |
| from torchvision import transforms | |
| import re | |
# Merge the two synthetic graph datasets and split them deterministically
# (seed 42): 20 % becomes the test split, then 12.5 % of the remainder becomes
# the validation split (roughly 70 / 10 / 20 overall).
Pyviz_dataset = load_dataset("Zaherrr/OOP_KG_Pyviz_Synthetic_Dataset", revision="Sorted_edges")
MarkMap_dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset")
combined_dataset = concatenate_datasets([Pyviz_dataset['data'], MarkMap_dataset['data']])

train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42)
train_val_split = train_test_split["train"].train_test_split(test_size=0.125, seed=42)

split_dataset = DatasetDict(
    train=train_val_split["train"],
    val=train_val_split["test"],
    test=train_test_split["test"],
)
def reshape_json_data_to_fit_visualize_graph(graph_data):
    """Turn column-oriented node/edge data into one record per node and per edge.

    ``graph_data["nodes"]`` holds parallel ``"id"`` and ``"label"`` lists and
    ``graph_data["edges"]`` holds parallel ``"source"`` and ``"target"`` lists.

    Returns:
        ``{"nodes": [{"id", "label"}, ...],
        "edges": [{"source", "target", "type": "->"}, ...]}``
    """
    node_columns = graph_data["nodes"]
    edge_columns = graph_data["edges"]
    node_records = [
        {"id": node_id, "label": node_columns["label"][position]}
        for position, node_id in enumerate(node_columns["id"])
    ]
    edge_records = [
        {"source": source, "target": edge_columns["target"][position], "type": "->"}
        for position, source in enumerate(edge_columns["source"])
    ]
    return {"nodes": node_records, "edges": edge_records}
def from_json_like_to_xml_like(data):
    """Serialise a graph of node/edge records into the XML-like target sequence.

    Args:
        data: dict with ``"nodes"`` (records with ``"id"``/``"label"``) and
            ``"edges"`` (records with ``"source"``/``"target"``).

    Returns:
        ``"<nodes>\\n<n id=..>label</n>...\\n</nodes>\\n<edges>\\n<e src=.. tgt=../>...\\n</edges>"``
    """
    node_tags = "".join(
        f'<n id="{node["id"]}">{node["label"]}</n>' for node in data["nodes"]
    )
    edge_tags = "".join(
        f'<e src="{edge["source"]}" tgt="{edge["target"]}"/>' for edge in data["edges"]
    )
    return f"<nodes>\n{node_tags}\n</nodes>\n<edges>\n{edge_tags}\n</edges>"
def flexible_node_shuffle(sequence, rng=None):
    """Randomly renumber the nodes of a serialised graph and canonicalise its edges.

    Shuffling node order on the fly is meant to reduce the bias a fixed node
    extraction order would otherwise teach the model.

    Args:
        sequence: graph serialised as ``<nodes>...</nodes>`` followed by
            ``<edges>...</edges>``, containing ``<n id="..">label</n>`` and
            ``<e src=".." tgt=".."/>`` elements.
        rng: optional ``random.Random`` used for the shuffle. Defaults to the
            module-level ``random`` generator.

    Returns:
        A new sequence whose nodes are renumbered 1..N in shuffled order and
        whose edges use the new ids. Edges are treated as undirected: each is
        written with the smaller id as ``src``, and the edges are sorted by
        ``(src, tgt)``. Sections are joined with ``<newline>`` tokens. If
        either section is missing, the input is returned unchanged.

    Raises:
        KeyError: if an edge references a node id that has no ``<n>`` element.
    """
    nodes_match = re.search(r'<nodes>(.*?)</nodes>', sequence, re.DOTALL)
    edges_match = re.search(r'<edges>(.*?)</edges>', sequence, re.DOTALL)
    if not nodes_match or not edges_match:
        print("Error: Could not find nodes or edges in the sequence.")
        return sequence

    nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', nodes_match.group(1), re.DOTALL)
    (rng if rng is not None else random).shuffle(nodes)

    # Old id (string from the regex) -> new 1-based position in the shuffled order.
    id_mapping = {old_id: new_id for new_id, (old_id, _) in enumerate(nodes, start=1)}
    new_nodes_content = "".join(
        f'<n id="{new_id}">{content}</n>' for new_id, (_, content) in enumerate(nodes, start=1)
    )

    # Rewrite every edge with the new ids, smaller id first, then sort so the
    # output order is independent of the original edge order.
    new_edges = []
    for src, tgt in re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', edges_match.group(1)):
        new_src, new_tgt = id_mapping[src], id_mapping[tgt]
        new_edges.append((new_src, new_tgt) if new_src < new_tgt else (new_tgt, new_src))
    new_edges.sort()
    new_edges_content = "".join(f'<e src="{src}" tgt="{tgt}"/>' for src, tgt in new_edges)

    return (
        f'<nodes><newline>{new_nodes_content}<newline></nodes><newline>'
        f'<edges><newline>{new_edges_content}<newline></edges>'
    )
class Sharpen:
    """Callable image transform that applies ``ImageFilter.SHARPEN``.

    Calling an instance returns the result of ``img.filter(ImageFilter.SHARPEN)``.
    """

    def __call__(self, img):
        sharpen_kernel = ImageFilter.SHARPEN
        return img.filter(sharpen_kernel)
| # with the graph edit distance validation | |
| import re | |
| from nltk import edit_distance | |
| import numpy as np | |
| import torch | |
| import pytorch_lightning as pl | |
| import mlflow | |
| import networkx as nx | |
| import Levenshtein | |
| import xml.etree.ElementTree as ET | |
| import multiprocessing | |
| import logging | |
| from torch.optim.lr_scheduler import LambdaLR | |
# Send INFO-and-above records to the root handler; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Node matching used to reorder a predicted graph so it aligns with the ground truth.
def match_nodes_by_label(G_pred, G_gt, distance=None):
    """Map every predicted node to the ground-truth node with the closest label.

    Args:
        G_pred: predicted graph; every node must carry a ``'label'`` attribute.
        G_gt: ground-truth graph; every node must carry a ``'label'`` attribute.
        distance: optional ``(str, str) -> float`` label distance where lower
            means closer. Defaults to ``DonutModelPLModule.normalized_levenshtein``.

    Returns:
        dict mapping each predicted node to its best-matching ground-truth node.
        Ties keep the first ground-truth node encountered. Several predicted
        nodes may map to the same ground-truth node. The result is empty when
        ``G_gt`` has no nodes.
    """
    if distance is None:
        distance = DonutModelPLModule.normalized_levenshtein
    gt_nodes = list(G_gt.nodes(data=True))
    if not gt_nodes:
        return {}
    node_mapping = {}
    for n_pred, pred_data in G_pred.nodes(data=True):
        # min() keeps the first minimum, matching the strict '<' scan. The
        # mapping is stored for every node, including falsy ids such as 0,
        # which the previous `if best_match:` check silently dropped.
        best_match, _ = min(
            ((n_gt, distance(pred_data['label'], gt_data['label'])) for n_gt, gt_data in gt_nodes),
            key=lambda pair: pair[1],
        )
        node_mapping[n_pred] = best_match
    return node_mapping
# Second half of the reordering step: relabel the prediction with ground-truth ids.
def rebuild_graph_with_mapped_nodes(G_pred, node_mapping):
    """Return a new undirected graph whose nodes use the mapped ground-truth ids.

    Each mapped node keeps the ``'label'`` it had in ``G_pred``. An edge of
    ``G_pred`` is carried over only when both endpoints appear in
    ``node_mapping``. Predicted nodes that are missing from the mapping are
    dropped.
    """
    aligned = nx.Graph()
    for pred_node, gt_node in node_mapping.items():
        aligned.add_node(gt_node, label=G_pred.nodes[pred_node]['label'])
    aligned.add_edges_from(
        (node_mapping[u], node_mapping[v])
        for u, v in G_pred.edges()
        if u in node_mapping and v in node_mapping
    )
    return aligned
class DonutModelPLModule(pl.LightningModule):
    """Lightning wrapper around a Donut ``VisionEncoderDecoderModel``.

    Training uses the model's own loss. During validation, every
    ``config["edit_distance_validation_frequency"]`` epochs (default 1), the
    module also generates sequences and scores them against the ground truth.
    Scoring uses normalized edit distance plus several graph-similarity
    metrics. Per-epoch averages are logged to MLflow.

    The graph helpers are static methods, so they also work without an
    instance, e.g. ``DonutModelPLModule.compare_graphs(pred_graph, gt_graph)``.
    """

    # Keys of the metric dict produced by ``compare_graphs``.
    GRAPH_METRIC_NAMES = (
        'fast_graph_similarity',
        'node_label_similarity',
        'edge_similarity',
        'degree_sequence_similarity',
        'node_coverage',
        'edge_precision',
        'edge_recall',
    )

    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.train_loss_epoch_total = 0.0
        self.val_loss_epoch_total = 0.0
        self.train_batch_count = 0
        self.val_batch_count = 0
        self.edit_distance_scores = []
        # One list per metric; each generation-based validation step appends its batch mean.
        self.graph_metrics = {name: [] for name in self.GRAPH_METRIC_NAMES}
        self.lr = config["lr"]
        self.warmup_steps = config["warmup_steps"]

    def _is_graph_validation_epoch(self):
        """Return True when generation-based metrics should run this epoch.

        Uses ``config["edit_distance_validation_frequency"]``, defaulting to 1
        (every epoch). Without a default, a missing key would evaluate
        ``int % None``.
        """
        frequency = self.config.get("edit_distance_validation_frequency", 1)
        return (self.current_epoch + 1) % frequency == 0

    def training_step(self, batch, batch_idx):
        pixel_values, labels, _ = batch
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.train_loss_epoch_total += loss.item()
        self.train_batch_count += 1
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        outputs = self.model(pixel_values, labels=labels)
        val_loss = outputs.loss
        self.val_loss_epoch_total += val_loss.item()
        self.val_batch_count += 1
        self.log("val_loss", val_loss)
        if not self._is_graph_validation_epoch():
            return
        logger.info(f'Finished epoch: {self.current_epoch + 1}')
        print(f'Finished epoch: {self.current_epoch + 1}')
        batch_size = pixel_values.shape[0]
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        try:
            outputs = self.model.generate(pixel_values,
                                          decoder_input_ids=decoder_input_ids,
                                          max_length=self.config.get("max_length", 512),
                                          early_stopping=True,
                                          pad_token_id=self.processor.tokenizer.pad_token_id,
                                          eos_token_id=self.processor.tokenizer.eos_token_id,
                                          use_cache=True,
                                          num_beams=1,
                                          bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                          return_dict_in_generate=True,)
            predictions = self.process_predictions(outputs)
            logger.info('Calculating graph metrics')
            print('Calculating graph metrics')
            levenshtein_scores, graph_scores = self.calculate_metrics(predictions, answers)
            logger.info('Finished calculating graph metrics')
            print('Finished calculating graph metrics')
            self.edit_distance_scores.append(np.mean(levenshtein_scores))
            for metric in self.graph_metrics:
                self.graph_metrics[metric].append(np.mean([score[metric] for score in graph_scores if metric in score]))
            self.log("val_edit_distance", np.mean(levenshtein_scores), prog_bar=True)
            for metric in self.graph_metrics:
                self.log(f"val_{metric}", self.graph_metrics[metric][-1], prog_bar=True)
        except Exception as e:
            # Generation/metric failures must not abort training; log and continue.
            logger.error(f"Error in validation step: {str(e)}")
            print(f"Error in validation step: {str(e)}")

    def process_predictions(self, outputs):
        """Decode generated sequences and strip special tokens / tokenizer spacing artefacts."""
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            try:
                seq = (
                    seq.replace(self.processor.tokenizer.eos_token, "")
                    .replace(self.processor.tokenizer.pad_token, "")
                    .replace('<n id=" ', '<n id="')
                    .replace('src=" ', 'src="')
                    .replace('tgt=" ', 'tgt="')
                    .replace('<newline>', '\n')
                )
                seq = re.sub(r"<s>", "", seq, count=1).strip()
                seq = seq.replace("<s>", "")
                predictions.append(seq)
            except Exception as e:
                logger.error(f"Error processing prediction: {str(e)}")
                print(f"Error processing prediction: {str(e)}")
                predictions.append("")  # Append empty string if processing fails
        return predictions

    def calculate_metrics(self, predictions, answers):
        """Score each prediction against its answer.

        Returns:
            ``(levenshtein_scores, graph_scores)`` with exactly one entry per
            (prediction, answer) pair. A pair whose processing fails keeps the
            worst scores: edit distance 1.0 unless it was already computed,
            and 0.0 for every graph metric.
        """
        levenshtein_scores = []
        graph_scores = []
        for pred, answer in zip(predictions, answers):
            edit_dist = 1.0  # worst possible score
            scores = self._empty_graph_scores()  # worst possible scores
            try:
                pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
                answer = answer.replace(self.processor.tokenizer.bos_token, "").replace(self.processor.tokenizer.eos_token, "").replace("<newline>", "\n")
                # max(..., 1) avoids ZeroDivisionError when both strings are empty.
                edit_dist = edit_distance(pred, answer) / max(len(pred), len(answer), 1)
                logger.info(f"Prediction: {pred}")
                logger.info(f"    Answer: {answer}")
                logger.info(f" Normed ED: {edit_dist}")
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {edit_dist}")
                pred_graph = self.create_graph_from_string(pred)
                answer_graph = self.create_graph_from_string(answer)
                # Align predicted node ids to the ground truth by label so the
                # comparison ignores the order in which nodes were generated.
                node_mapping = match_nodes_by_label(pred_graph, answer_graph)
                pred_graph_aligned = rebuild_graph_with_mapped_nodes(pred_graph, node_mapping)
                logger.info('Calculating the GED')
                print('Calculating the GED')
                scores = self.compare_graphs_with_timeout(pred_graph_aligned, answer_graph, timeout=60)
                logger.info('Got the GED results')
                print('Got the GED results')
            except Exception as e:
                logger.error(f"Error calculating metrics: {str(e)}")
                print(f"Error calculating metrics: {str(e)}")
            # Append once per pair so the two lists stay index-aligned.
            levenshtein_scores.append(edit_dist)
            graph_scores.append(scores)
        return levenshtein_scores, graph_scores

    @staticmethod
    def _empty_graph_scores():
        """Return the worst-case metric dict (every metric 0.0)."""
        return {name: 0.0 for name in DonutModelPLModule.GRAPH_METRIC_NAMES}

    @staticmethod
    def _compare_graphs_into(pred_graph, answer_graph, return_dict):
        """Process target: store ``compare_graphs`` output under ``'result'``.

        A named static method rather than a closure, so the target can be
        pickled when multiprocessing uses the 'spawn' start method.
        """
        return_dict['result'] = DonutModelPLModule.compare_graphs(pred_graph, answer_graph)

    @staticmethod
    def compare_graphs_with_timeout(pred_graph, answer_graph, timeout=60):
        """Run ``compare_graphs`` in a child process, giving up after ``timeout`` seconds.

        Returns worst-case scores on timeout or when the child produced no result.
        """
        # The context manager shuts the Manager's server process down on every path.
        with multiprocessing.Manager() as manager:
            return_dict = manager.dict()
            p = multiprocessing.Process(
                target=DonutModelPLModule._compare_graphs_into,
                args=(pred_graph, answer_graph, return_dict),
            )
            p.start()
            p.join(timeout)
            if p.is_alive():
                logger.warning('Graph comparison timed out. Returning default values.')
                print('Graph comparison timed out. Returning default values.')
                p.terminate()
                p.join()
                return DonutModelPLModule._empty_graph_scores()
            result = return_dict.get('result')
        return result if result is not None else DonutModelPLModule._empty_graph_scores()

    @staticmethod
    def create_graph_from_string(xml_string):
        """Parse ``<n id=..>label</n>`` / ``<e src=.. tgt=../>`` elements into an undirected graph.

        Node ids stay strings; labels are lower-cased.
        """
        G = nx.Graph()
        try:
            nodes = re.findall(r'<n id="(\d+)">(.*?)</n>', xml_string, re.DOTALL)
            for node_id, label in nodes:
                G.add_node(node_id, label=label.lower())
            edges = re.findall(r'<e src="(\d+)" tgt="(\d+)"/>', xml_string)
            for src, tgt in edges:
                G.add_edge(src, tgt)
        except Exception as e:
            logger.error(f"Error creating graph from string: {str(e)}")
            print(f"Error creating graph from string: {str(e)}")
        return G

    @staticmethod
    def normalized_levenshtein(s1, s2):
        """Levenshtein distance divided by the longer length (0 for two empty strings)."""
        distance = Levenshtein.distance(s1, s2)
        max_length = max(len(s1), len(s2))
        return distance / max_length if max_length > 0 else 0

    @staticmethod
    def _edge_set(G):
        """Return G's edges in a form that compares correctly across graphs.

        For undirected graphs ``G.edges()`` reports each edge in an orientation
        that depends on node insertion order, so the same edge can appear as
        ``(a, b)`` in one graph and ``(b, a)`` in another. Frozensets make the
        comparison orientation-independent. Directed graphs keep plain tuples.
        """
        if G.is_directed():
            return set(G.edges())
        return {frozenset(edge) for edge in G.edges()}

    @staticmethod
    def calculate_node_coverage(G1, G2, threshold=0.2):
        """Share of G1 nodes whose label is within ``threshold`` of some G2 label.

        Normalized by the larger node count; 0.0 when both graphs are empty.
        """
        total = max(len(G1), len(G2))
        if total == 0:
            return 0.0
        matched_nodes = 0
        for n1 in G1.nodes(data=True):
            if any(DonutModelPLModule.normalized_levenshtein(n1[1]['label'], n2[1]['label']) <= threshold
                   for n2 in G2.nodes(data=True)):
                matched_nodes += 1
        return matched_nodes / total

    @staticmethod
    def node_label_similarity(G1, G2):
        """Mean over G1 labels of the best ``1 - normalized_levenshtein`` against G2 labels."""
        labels1 = list(nx.get_node_attributes(G1, 'label').values())
        labels2 = list(nx.get_node_attributes(G2, 'label').values())
        total_similarity = 0
        for label1 in labels1:
            similarities = [1 - DonutModelPLModule.normalized_levenshtein(label1, label2) for label2 in labels2]
            total_similarity += max(similarities) if similarities else 0
        return total_similarity / len(labels1) if labels1 else 0

    @staticmethod
    def edge_similarity(G1, G2):
        """Shared edges divided by the larger edge count (1 when neither graph has edges)."""
        denominator = max(G1.number_of_edges(), G2.number_of_edges())
        if denominator == 0:
            return 1
        shared = DonutModelPLModule._edge_set(G1) & DonutModelPLModule._edge_set(G2)
        return len(shared) / denominator

    @staticmethod
    def degree_sequence_similarity(G1, G2):
        """Compare sorted degree sequences: ``1 - sum|d1 - d2| / (2 * sum(d1))``, 0.0 if either is empty."""
        seq1 = sorted([d for n, d in G1.degree()], reverse=True)
        seq2 = sorted([d for n, d in G2.degree()], reverse=True)
        if not seq1 or not seq2:
            return 0.0
        # Pad the shorter sequence with zeros so both have the same length.
        max_len = max(len(seq1), len(seq2))
        seq1 += [0] * (max_len - len(seq1))
        seq2 += [0] * (max_len - len(seq2))
        diff_sum = sum(abs(x - y) for x, y in zip(seq1, seq2))
        return 1 - diff_sum / (2 * sum(seq1)) if sum(seq1) > 0 else 0.0

    @staticmethod
    def fast_graph_similarity(G1, G2):
        """Unweighted mean of node-label, edge and degree-sequence similarity."""
        node_sim = DonutModelPLModule.node_label_similarity(G1, G2)
        edge_sim = DonutModelPLModule.edge_similarity(G1, G2)
        degree_sim = DonutModelPLModule.degree_sequence_similarity(G1, G2)
        return (node_sim + edge_sim + degree_sim) / 3

    @staticmethod
    def compare_graphs(G1, G2):
        """Compute every graph metric for a prediction ``G1`` against ground truth ``G2``.

        Returns:
            dict keyed by ``GRAPH_METRIC_NAMES``. ``edge_precision`` is the
            share of predicted edges that are in the ground truth, and
            ``edge_recall`` is the share of ground-truth edges that were
            predicted. Every metric is 0.0 if any computation fails.
        """
        try:
            node_coverage = DonutModelPLModule.calculate_node_coverage(G1, G2)
            pred_edges = DonutModelPLModule._edge_set(G1)
            gt_edges = DonutModelPLModule._edge_set(G2)
            correct_edges = len(pred_edges & gt_edges)
            edge_precision = correct_edges / len(pred_edges) if pred_edges else 0
            edge_recall = correct_edges / len(gt_edges) if gt_edges else 0
            return {
                "fast_graph_similarity": DonutModelPLModule.fast_graph_similarity(G1, G2),
                "node_label_similarity": DonutModelPLModule.node_label_similarity(G1, G2),
                "edge_similarity": DonutModelPLModule.edge_similarity(G1, G2),
                "degree_sequence_similarity": DonutModelPLModule.degree_sequence_similarity(G1, G2),
                "node_coverage": node_coverage,
                "edge_precision": edge_precision,
                "edge_recall": edge_recall
            }
        except Exception as e:
            logger.error(f"Error comparing graphs: {str(e)}")
            print(f"Error comparing graphs: {str(e)}")
            return DonutModelPLModule._empty_graph_scores()

    def configure_optimizers(self):
        """AdamW with a per-step linear warmup to ``self.lr``, constant afterwards."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)

        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                return float(current_step) / float(max(1, self.warmup_steps))
            return 1.0  # You can replace this with a decay function like exponential decay

        scheduler = LambdaLR(optimizer, lr_lambda)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',  # Update the learning rate after every training step
                'frequency': 1,  # How often the scheduler is called (every step)
            }
        }

    def on_validation_epoch_end(self):
        # max(1, ...) guards against an epoch that ran no validation batches.
        avg_val_loss = self.val_loss_epoch_total / max(1, self.val_batch_count)
        mlflow.log_metric("validation_crossentropy_loss", avg_val_loss, step=self.current_epoch)
        self.val_loss_epoch_total = 0.0
        self.val_batch_count = 0
        if self._is_graph_validation_epoch():
            if self.edit_distance_scores:
                mlflow.log_metric("validation_edit_distance", self.edit_distance_scores[-1], step=self.current_epoch)
            for metric in self.graph_metrics:
                if self.graph_metrics[metric]:
                    mlflow.log_metric(f"validation_{metric}", self.graph_metrics[metric][-1], step=self.current_epoch)
        print('[INFO] - Finished the validation for epoch ', self.current_epoch + 1)

    def on_train_epoch_end(self):
        print(f'[INFO] - Finished epoch {self.current_epoch + 1}')
        # max(1, ...) guards against an epoch that ran no training batches.
        avg_train_loss = self.train_loss_epoch_total / max(1, self.train_batch_count)
        print(f'[INFO] - Train loss: {avg_train_loss}')
        mlflow.log_metric("training_crossentropy_loss", avg_train_loss, step=self.current_epoch)
        self.train_loss_epoch_total = 0.0
        self.train_batch_count = 0
        if ((self.current_epoch + 1) % self.config.get("save_model_weights_frequency", 10)) == 0:
            self.save_model()

    def on_fit_end(self):
        self.save_model()

    def save_model(self):
        """Save the model locally to ``Donut_model`` and log model + processor parts to MLflow."""
        model_dir = "Donut_model"
        os.makedirs(model_dir, exist_ok=True)
        self.model.save_pretrained(model_dir)
        print('[INFO] - Saving the model to dagshub using mlflow')
        mlflow.transformers.log_model(
            transformers_model={
                "model": self.model,
                "feature_extractor": self.processor.feature_extractor,
                "image_processor": self.processor.image_processor,
                "tokenizer": self.processor.tokenizer
            },
            artifact_path=model_dir,
            # Set task explicitly since MLflow cannot infer it from the loaded model
            task="image-to-text"
        )
        print('[INFO] - Saved the model to dagshub using mlflow')

    # NOTE(review): these return module-level `train_dataloader` / `val_dataloader`
    # objects, which this file does not define; calling them here would raise NameError.
    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader
# Hyper-parameters for the Lightning wrapper. The constructor reads "lr" and
# "warmup_steps"; the app itself uses the module only to compute metrics.
config = dict(
    max_epochs=200,
    check_val_every_n_epoch=1,   # val_check_interval=0.2 was an alternative
    gradient_clip_val=1.0,
    lr=8e-4,                     # 3e-4 and 3e-5 were also tried
    train_batch_sizes=[1],       # 8 was also tried
    val_batch_sizes=[1],
    num_nodes=1,
    warmup_steps=200,            # 800/8*30/10, i.e. 10 %
    verbose=True,
)
model_module = DonutModelPLModule(config, processor, model)
dataset = split_dataset['test']  # the UI browses the held-out test split

# Prefer a GPU for inference when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
class Sharpen:
    """Image transform: ``Sharpen()(img)`` returns ``img.filter(ImageFilter.SHARPEN)``."""

    def __call__(self, img):
        result = img.filter(ImageFilter.SHARPEN)
        return result
def preprocess_image(image):
    """Return a sharpened PIL version of *image*.

    Inputs that are not already ``PIL.Image.Image`` instances (e.g. arrays)
    are converted with ``Image.fromarray`` first.
    """
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    return Sharpen()(pil_image)
def perform_inference(image, max_length=None):
    """Run the Donut model on one image and return the raw decoded sequence.

    Args:
        image: anything ``processor(images=...)`` accepts (the UI passes a PIL image).
        max_length: generation length cap. Defaults to
            ``model.config.decoder.max_length``. The body previously referenced
            an undefined global ``max_length`` and raised NameError on every call.

    Returns:
        The decoded first sequence, special tokens included (``batch_decode``
        is called without ``skip_special_tokens``).
    """
    if max_length is None:
        # NOTE(review): assumes training stored the intended sequence length on
        # the decoder config (this value is printed at load time) — confirm.
        max_length = model.config.decoder.max_length
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)
    batch_size = pixel_values.shape[0]
    # Every sequence starts from the decoder start token.
    decoder_input_ids = torch.full((batch_size, 1), model.config.decoder_start_token_id, device=device)
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=max_length,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token in generated output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    decoded_output = processor.batch_decode(outputs.sequences)[0]
    print("Raw model output:", decoded_output)
    return decoded_output
def display_example(index):
    """Return ``(image of dataset[index], None, None)``; the Nones reset two other outputs."""
    selected_image = dataset[index]["image"]
    return selected_image, None, None
def from_json_like_to_xml_like(data):
    """Render node/edge records as the ``<nodes>...</nodes>\\n<edges>...</edges>`` sequence.

    Args:
        data: dict with ``"nodes"`` (``id``/``label`` records) and ``"edges"``
            (``source``/``target`` records).
    """
    sections = []
    node_body = "".join(f'<n id="{n["id"]}">{n["label"]}</n>' for n in data["nodes"])
    sections.append("<nodes>\n" + node_body + "\n</nodes>")
    edge_body = "".join(f'<e src="{e["source"]}" tgt="{e["target"]}"/>' for e in data["edges"])
    sections.append("<edges>\n" + edge_body + "\n</edges>")
    return "\n".join(sections)
def reshape_json_data_to_fit_visualize_graph(graph_data):
    """Convert parallel-list node/edge columns into per-item records.

    Nodes become ``{"id", "label"}`` dicts; edges become
    ``{"source", "target", "type": "->"}`` dicts.
    """
    node_cols = graph_data["nodes"]
    edge_cols = graph_data["edges"]

    out_nodes = []
    for i, node_id in enumerate(node_cols["id"]):
        out_nodes.append({"id": node_id, "label": node_cols["label"][i]})

    out_edges = []
    for i, source in enumerate(edge_cols["source"]):
        out_edges.append({"source": source, "target": edge_cols["target"][i], "type": "->"})

    return {"nodes": out_nodes, "edges": out_edges}
def get_ground_truth(index):
    """Return (and print) the ground-truth XML-like sequence for ``dataset[index]``."""
    records = reshape_json_data_to_fit_visualize_graph(dataset[index])
    # JSON round trip so the serializer only sees plain JSON types.
    sequence = from_json_like_to_xml_like(json.loads(json.dumps(records)))
    print(f'Ground truth sequence: {sequence}')
    return sequence
def transform_image(img, index, physics_enabled):
    """Run the model on *img* and build the three UI outputs.

    ``index`` is part of the UI callback signature but is not used here.

    Returns:
        ``(graph iframe HTML with 500px height, indented graph JSON, raw model sequence)``
    """
    raw_sequence = perform_inference(img)
    graph = transform_sequence(raw_sequence)
    # Swap the full-viewport height for a fixed height that fits the page layout.
    iframe_html = visualize_graph(graph, physics_enabled).replace('height: 100vh;', 'height: 500px;')
    return iframe_html, json.dumps(graph, indent=2), raw_sequence
| import re | |
| from typing import Dict, List, Tuple | |
def transform_sequence(sequence: str) -> Dict[str, List[Dict[str, str]]]:
    """Parse a generated graph sequence into node and edge records.

    Tolerates whitespace after the opening quote of ids (``id=" 1"``) and
    strips labels.

    Returns:
        ``{"nodes": [{"id", "label"}], "edges": [{"source", "target", "type": "->"}]}``
        with all values as strings.

    Raises:
        ValueError: if the ``<nodes>`` or ``<edges>`` section is missing.
    """
    sections = {}
    for tag in ("nodes", "edges"):
        found = re.search(rf'<{tag}>(.*?)</{tag}>', sequence, re.DOTALL)
        if found is None:
            raise ValueError("Invalid input sequence: nodes or edges not found")
        sections[tag] = found.group(1)

    node_pairs = re.findall(r'<n id="\s*(\d+)">(.*?)</n>', sections["nodes"])
    edge_pairs = re.findall(r'<e src="\s*(\d+)" tgt="\s*(\d+)"/>', sections["edges"])
    return {
        "nodes": [{"id": node_id.strip(), "label": label.strip()} for node_id, label in node_pairs],
        "edges": [
            {"source": src.strip(), "target": tgt.strip(), "type": "->"}
            for src, tgt in edge_pairs
        ],
    }
| # function to visualize the extracted graph | |
| import json | |
| from pyvis.network import Network | |
def create_graph(nodes, edges, physics_enabled=True):
    """Build a dark-themed pyvis ``Network`` from node and edge records.

    Args:
        nodes: dicts with ``"id"`` and ``"label"``. A node labelled exactly
            ``"OOP"`` is drawn blue; all others are green.
        edges: dicts with ``"source"``, ``"target"`` and ``"type"`` (the type
            becomes the edge tooltip).
        physics_enabled: written into the node and global physics options.

    Returns:
        The configured ``pyvis.network.Network``.
    """
    net = Network(
        notebook=True,
        height="100vh",
        width="100vw",
        bgcolor="#222222",
        font_color="white",
        cdn_resources="remote",
    )
    for record in nodes:
        node_id = record["id"]
        label = record["label"]
        net.add_node(node_id, label=label, title=label, color="blue" if label == "OOP" else "green")
    for record in edges:
        net.add_edge(record["source"], record["target"], title=record["type"])

    net.force_atlas_2based(gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.08, damping=0.4)

    physics_options = {"enabled": physics_enabled, "stabilization": {"enabled": True, "iterations": 200}}
    net.set_options(json.dumps({
        "nodes": {"physics": physics_enabled},
        "edges": {"smooth": True},
        "interaction": {"hover": True, "zoomView": True},
        "physics": physics_options,
    }))
    return net
def visualize_graph(json_data, physics_enabled=True):
    """Render a graph as an ``<iframe>`` whose ``srcdoc`` holds the pyvis page.

    Args:
        json_data: dict with ``"nodes"`` and ``"edges"`` lists (as built by
            ``transform_sequence``), or that dict as a JSON string.
        physics_enabled: forwarded to ``create_graph``.

    Returns:
        HTML for a single full-width iframe. The iframe and the network div
        both use ``height: 100vh;``, which callers may rewrite.
    """
    import html as html_lib

    data = json.loads(json_data) if isinstance(json_data, str) else json_data
    net = create_graph(data["nodes"], data["edges"], physics_enabled)
    page = net.generate_html()
    page = page.replace(
        '<div id="mynetwork"', '<div id="mynetwork" style="height: 100vh; width: 100%;"'
    )
    # HTML-escape the whole page for the srcdoc attribute. The old approach
    # rewrote every ' to ", which corrupted labels and JS strings containing an
    # apostrophe (e.g. "Liskov's ...") and broke rendering.
    escaped_page = html_lib.escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 100vh; border: none; margin: 0; padding: 0;" srcdoc="{escaped_page}"></iframe>"""
def update_physics(json_data, physics_enabled):
    """Re-render the current graph with physics switched on or off.

    Args:
        json_data: graph JSON text from the JSON output box. ``None`` or a
            blank string means nothing has been rendered yet.
        physics_enabled: new physics setting.

    Returns:
        iframe HTML sized 500px high, or ``None`` when there is no graph.
    """
    # A blank code box would otherwise make json.loads raise.
    if not json_data or not json_data.strip():
        return None
    graph = json.loads(json_data)
    return visualize_graph(graph, physics_enabled).replace('height: 100vh;', 'height: 500px;')
# Compute the graph similarity metrics between the prediction and the ground truth.
def calculate_and_display_metrics(pred_graph, ground_truth_graph):
    """Score a raw predicted sequence against the ground truth and format the result.

    Args:
        pred_graph: raw decoded model output. It may contain ``<s>`` and
            ``<newline>`` tokens and a stray space after the opening quote of
            id attributes.
        ground_truth_graph: ground-truth XML-like sequence.

    Returns:
        ``"Detailed Metrics:\\n"`` followed by one ``name: value`` line per graph
        metric, or a prompt when either input is missing or empty.
    """
    # Hidden textboxes can hold "" instead of None before they are filled.
    if not pred_graph or not ground_truth_graph:
        return "Please generate a prediction and ensure a ground truth graph is available."
    # Normalise the raw model output into the format the metric parser expects.
    pred_graph = (
        pred_graph.replace('<s>', "")
        .replace("<newline>", "\n")
        .replace('src=" ', 'src="')
        .replace('tgt=" ', 'tgt="')
        .replace('<n id=" ', '<n id="')
    )
    print(f'Prediction: {pred_graph}')
    _, graph_scores = model_module.calculate_metrics([pred_graph], [ground_truth_graph])
    lines = "".join(f"{key}: {value:.4f}\n" for key, value in graph_scores[0].items())
    return "Detailed Metrics:\n" + lines
def create_interface():
    """Build the Gradio Blocks UI for browsing test examples and running the model.

    Wiring:
      * moving the slider shows that example's image (and clears the graph
        and JSON outputs), then loads its ground-truth sequence into a hidden
        textbox;
      * "Transform" runs inference, renders the predicted graph and JSON,
        stores the raw sequence in a hidden textbox, then computes metrics;
      * "Calculate Metrics" recomputes the metrics from the hidden textboxes;
      * the physics checkbox re-renders the graph from the JSON output.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Knowledge Graph Visualizer with Model Inference")
        with gr.Row():
            # Index into the test split (`dataset`).
            index_slider = gr.Slider(
                minimum=0,
                maximum=len(dataset) - 1,
                step=1,
                label="Example Index"
            )
        with gr.Row():
            image_output = gr.Image(type="pil", label="Image", height=500, interactive=False)
            graph_output = gr.HTML(label="Knowledge Graph")
        with gr.Row():
            transform_button = gr.Button("Transform")
            physics_toggle = gr.Checkbox(label="Enable Physics", value=True)
        with gr.Row():
            json_output = gr.Code(language="json", label="Graph JSON Data")
            # Hidden state feeding the metric computation (a visible gr.JSON was used before).
            ground_truth_output = gr.Textbox(visible=False)
            predicted_raw_sequence = gr.Textbox(visible=False)
        with gr.Row():
            metrics_button = gr.Button("Calculate Metrics")
            metrics_output = gr.Textbox(label="Similarity Metrics", lines=10)
        # Slider: show the example, then load its ground truth.
        index_slider.change(
            fn=display_example,
            inputs=[index_slider],
            outputs=[image_output, graph_output, json_output],
        ).then(
            fn=get_ground_truth,
            inputs=[index_slider],
            outputs=[ground_truth_output],
        )
        # Transform: predict and render, then score against the ground truth.
        transform_button.click(
            fn=transform_image,
            inputs=[image_output, index_slider, physics_toggle],
            outputs=[graph_output, json_output, predicted_raw_sequence],
        ).then(
            fn=calculate_and_display_metrics,
            inputs=[predicted_raw_sequence, ground_truth_output],
            outputs=[metrics_output]
        )
        metrics_button.click(
            fn=calculate_and_display_metrics,
            inputs=[predicted_raw_sequence, ground_truth_output],
            outputs=[metrics_output],
        )
        physics_toggle.change(
            fn=update_physics,
            inputs=[json_output, physics_toggle],
            outputs=[graph_output],
        )
    return demo
# Build the Gradio app and serve it only when this file is run as a script.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(debug=True, share=True)