# aggile.py — LLM-driven knowledge-graph generation for plain text
class Aggile:
    """
    Graph generator for plain text.

    Drives an OpenAI-style ``chat.completions`` client to turn raw text into
    knowledge-graph triplets (subject, predicate, object).
    """

    def __init__(self, client):
        """
        :client: LLM client exposing ``client.chat.completions.create``; the
                 response must support ``resp.choices[0].get('message')['content']``
        """
        self.client = client
        # Template for subject extraction; ``{n}`` is substituted by
        # _get_subj at call time.  (The original code built this with an
        # f-string over ``n = None``, which baked the literal word "None"
        # into the prompt and ignored the ``n`` argument of _get_subj.)
        self.subj_prompt = """
        extract {n} collocations describing key concepts, keywords, named entities from the provided source
        """
        self.obj_prompt = """
        extract 5-10 most representative collocations from the provided source that are related to the provided concept
        """
        self.pred_prompt = """
        define the relationship between two words: generate a verb or a phrase decribing a relationship between two entities; return a predicate for a knowledge graph triplet
        """

    def _get_subj(self, text, n=10):
        """
        Extract key entities (future triplet subjects) from the text:
        named entities, keywords, concepts.

        :text: input text (str)
        :n: number of entities to request from the model (int)
        :return: {"core_concepts": list of extracted keywords} (dict)
        """
        import ast
        response = self.client.chat.completions.create(
            messages=[
                {"role": "system", "content": self.subj_prompt.format(n=n)},
                {"role": "user", "content": text},
            ],
            response_format={
                "type": "json",
                "value": {
                    "properties": {
                        "core_concepts": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                    }
                },
            },
            stream=False,
            max_tokens=1024,
            temperature=0.5,
            top_p=0.1,
        )
        core_concepts = response.choices[0].get('message')['content']
        # The model returns dict-literal text; literal_eval parses it without
        # executing arbitrary code.
        return ast.literal_eval(core_concepts)

    def __extract_relations(self, word, text):
        """
        Extract collocations related to one concept (future triplet objects),
        grounded in the source text.

        :word: one concept produced by _get_subj (str)
        :text: input text (str)
        :return: {"related_concepts": list of related collocations} (dict)
        """
        import ast
        response = self.client.chat.completions.create(
            messages=[
                {"role": "system", "content": self.obj_prompt},
                {"role": "user", "content": f"concept = {word}, source = {text}"},
            ],
            response_format={
                "type": "json",
                "value": {
                    "properties": {
                        "related_concepts": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                    }
                },
            },
            stream=False,
            max_tokens=512,
            temperature=0.5,
            top_p=0.1,
        )
        related_concepts = response.choices[0].get('message')['content']
        return ast.literal_eval(related_concepts)

    def _get_obj(self, text):
        """
        Generate related concepts (objects) for every extracted subject.

        :text: input text (str)
        :return: {subject: {"related_concepts": [...]}} (dict)
        """
        # Generate the list of subjects, then one object list per subject.
        core_concepts = self._get_subj(text, n=10)
        return {
            word: self.__extract_relations(word, text)
            for word in core_concepts['core_concepts']
        }

    def __generate_predicates(self, subj, obj):
        """
        Generate a single predicate linking a subject and an object.

        :subj: one generated subject from core_concepts (str)
        :obj: one generated object from relations (str)
        :return: one relevant predicate to form a triplet (str)
        """
        import ast
        response = self.client.chat.completions.create(
            messages=[
                {"role": "system", "content": self.pred_prompt},
                {"role": "user", "content": f"what is the relationship between {subj} and {obj}? return a predicate only"},
            ],
            response_format={
                "type": "json",
                "value": {
                    "properties": {
                        "predicate": {"type": "string"},
                    }
                },
            },
            stream=False,
            max_tokens=512,
            temperature=0.5,
            top_p=0.1,
        )
        predicate = response.choices[0].get('message')['content']
        # Return the predicate string only, not the whole parsed dict.
        return ast.literal_eval(predicate)['predicate']

    def form_triples(self, text):
        """
        Build subject-predicate-object triplets for the whole text.

        :text: input text (str)
        :return: {subject: [{'subject': ..., 'predicate': ..., 'object': ...}, ...]}
        """
        # Objects grouped by subject.
        relations = self._get_obj(text)
        triplets = {}
        for subj, related in relations.items():
            triplets[subj] = []
            for obj in related['related_concepts']:
                # Hallucination guard: a subject related only to itself is a
                # degenerate triplet.  Checking BEFORE generating the
                # predicate also avoids a wasted LLM call (the original
                # generated the predicate first and then discarded it).
                if subj == obj:
                    continue
                triplets[subj].append({
                    'subject': subj,
                    'predicate': self.__generate_predicates(subj, obj),
                    'object': obj,
                })
        return triplets
class Graph:
    """
    Plotly renderer for a triplet dictionary produced by Aggile.form_triples.
    """

    def __init__(self, triplets):
        # triplets: {subject: [{'subject':.., 'predicate':.., 'object':..}, ...]}
        self.triplets = triplets

    def build_graph(self):
        """
        Build a force-directed graph figure with predicate labels on edges.

        :return: plotly.graph_objects.Figure (caller may .show() or .write_html())
        """
        import plotly.graph_objects as go
        import networkx as nx
        from collections import Counter

        # Collect (subject, object, predicate) edges from the triplet lists.
        # (The original also built an unused `nodes` set and an unused
        # random `predicate_colors` map; both were dead code and are gone.)
        edges = []
        for rels in self.triplets.values():
            for rel in rels:
                edges.append((rel['subject'], rel['object'], rel['predicate']))

        # Undirected graph; add_edge creates the endpoint nodes implicitly.
        G = nx.Graph()
        for subj, obj, pred in edges:
            G.add_edge(subj, obj, label=pred)

        # Deterministic force-directed layout (fixed seed).
        pos = nx.spring_layout(G, seed=42)

        node_labels = list(G.nodes())
        node_x = [pos[node][0] for node in node_labels]
        node_y = [pos[node][1] for node in node_labels]
        # Node colour encodes how many edge endpoints touch the node.
        node_degrees = Counter(node for edge in edges for node in edge[:2])

        # Edge polylines for Plotly; None breaks the line between edges.
        edge_x = []
        edge_y = []
        for subj, obj, _ in edges:
            x0, y0 = pos[subj]
            x1, y1 = pos[obj]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]

        fig = go.Figure()

        # Edges.
        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='text',
            mode='lines'
        ))

        # Nodes with uniform size and bold labels.
        fig.add_trace(go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            marker=dict(
                size=25,  # uniform node size for all nodes
                color=[node_degrees[node] for node in node_labels],
                colorbar=dict(title='Connections')
            ),
            text=node_labels,
            hoverinfo='text',
            textposition='top center',
            textfont=dict(size=13, weight="bold")
        ))

        # One text trace per edge, placing the predicate at the midpoint.
        for subj, obj, predicate_label in edges:
            x0, y0 = pos[subj]
            x1, y1 = pos[obj]
            fig.add_trace(go.Scatter(
                x=[(x0 + x1) / 2], y=[(y0 + y1) / 2],
                mode='text',
                text=[predicate_label],
                textposition='middle center',
                showlegend=False,
                textfont=dict(size=10)
            ))

        fig.update_layout(
            showlegend=False,
            margin=dict(l=0, r=0, t=0, b=0),
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
            title="Force-Directed Graph with Predicate Labels on Nodes"
        )
        # Figure is returned, not written to disk; the caller decides.
        return fig