| | import gradio as gr |
| | import networkx as nx |
| | import matplotlib.pyplot as plt |
| | import io |
| | from PIL import Image |
| | import json |
| | from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline |
| | import os |
| | import re |
| | from collections import defaultdict |
| | import torch |
| |
|
# Disable the optional hf_transfer download backend.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER into a module constant when it
# is first imported, and the transformers import at the top of this file has
# already imported it, so the assignment above comes too late for this process.
# Also overwrite the cached flag. NOTE(review): this assumes huggingface_hub
# looks the flag up as `constants.HF_HUB_ENABLE_HF_TRANSFER` at download time
# (recent releases do); the fully robust fix is to set the variable before
# importing transformers.
try:
    import huggingface_hub.constants as _hf_constants
except ImportError:
    _hf_constants = None
if _hf_constants is not None:
    _hf_constants.HF_HUB_ENABLE_HF_TRANSFER = False

print("Loading NER model...")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
# aggregation_strategy="simple" merges word pieces into whole entity groups.
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")

print("Loading REBEL model...")
# Populated by load_rebel_model(); stays None if no checkpoint could be loaded.
rebel_pipeline = None
| |
|
def load_rebel_model(models_to_try=None):
    """Load the first REBEL checkpoint that succeeds into the global ``rebel_pipeline``.

    Args:
        models_to_try: optional iterable of Hugging Face model ids, tried in
            order. Defaults to the two smaller ids followed by
            "Babelscape/rebel-large".

    Returns:
        True if a text2text-generation pipeline was created (and stored in
        ``rebel_pipeline``), False if every candidate failed. Load errors are
        printed and swallowed: REBEL is optional, and callers fall back to
        type-based relations when it is unavailable.
    """
    global rebel_pipeline

    if models_to_try is None:
        models_to_try = (
            "Babelscape/rebel-small",
            "Babelscape/rebel-base",
            # NOTE(review): the two ids above may not be published on the Hub;
            # rebel-large is the checkpoint documented on the REBEL model card,
            # tried last because it is the largest download.
            "Babelscape/rebel-large",
        )

    use_cuda = torch.cuda.is_available()
    for model_name in models_to_try:
        print(f"Trying to load {model_name}...")
        try:
            rebel_pipeline = pipeline(
                "text2text-generation",
                model=model_name,
                tokenizer=model_name,
                # An explicit device index needs no extra package, whereas
                # device_map="auto" requires `accelerate` to be installed.
                device=0 if use_cuda else -1,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
        except Exception as e:  # best effort: try the next candidate
            print(f"❌ Failed to load {model_name}: {str(e)[:100]}...")
            continue
        print(f"✅ Successfully loaded {model_name}")
        return True

    print("❌ Could not load any REBEL model")
    return False
| |
|
| | |
# Attempt the REBEL load once, at import time. The resulting flag selects the
# REBEL path in extract_relations and drives the status shown in the UI.
rebel_loaded: bool = load_rebel_model()
| |
|
def extract_entities(text):
    """Extract named entities from *text* with the BERT-NER pipeline.

    Returns:
        A list of dicts, one per aggregated entity, with keys:
        - "text": the entity's surface form, sliced from *text* with the
          pipeline's character offsets;
        - "label": the aggregated entity group (e.g. PER, ORG, LOC, MISC);
        - "start" / "end": character offsets into *text*;
        - "confidence": the pipeline score rounded to 3 decimals.
    """
    processed_entities = []
    for entity in ner_pipeline(text):
        start = int(entity["start"])
        end = int(entity["end"])
        # The aggregated "word" is rebuilt from word pieces and can differ from
        # the input (spaces may be inserted around punctuation such as hyphens).
        # Slicing the original text keeps labels and graph nodes faithful to
        # the source; fall back to the cleaned word if the slice is empty.
        surface = text[start:end].strip() or entity["word"].replace("##", "")
        processed_entities.append({
            "text": surface,
            "label": entity["entity_group"],
            "start": start,
            "end": end,
            "confidence": round(float(entity["score"]), 3),
        })
    return processed_entities
| |
|
def parse_rebel_output(generated_text):
    """Parse REBEL's generated text into relation triplet dicts.

    REBEL linearises its output as ``<triplet> HEAD <subj> TAIL <obj> RELATION``,
    where one head may be followed by several ``<subj> TAIL <obj> RELATION``
    groups. HEAD becomes the subject and TAIL the object. The markers are
    special tokens, so they are only present when the output was decoded with
    special tokens kept. If no valid marker-based triplet is found, a loose
    word-pattern fallback is tried on the raw text.

    Args:
        generated_text: decoded model output; "<s>", "</s>" and "<pad>" are
            ignored. None is treated as empty.

    Returns:
        A list of dicts with keys "subject", "relation", "object",
        "confidence" and "source". Marker-based triplets have confidence 0.9
        and source "REBEL"; fallback ones 0.7 and "REBEL-fallback". Triplets
        rejected by validate_triplet are dropped, and exact duplicates from
        the marker path are reported once.
    """
    print(f"Raw REBEL output: {generated_text}")
    text = (generated_text or "").strip()

    triplets = []
    seen = set()
    for head, tail, raw_relation in _split_rebel_triplets(text):
        subject = clean_text(head)
        obj = clean_text(tail)
        relation = clean_text(raw_relation)
        if not validate_triplet(subject, relation, obj):
            continue
        formatted = format_relation(relation)
        key = (subject.lower(), formatted, obj.lower())
        if key in seen:
            continue
        seen.add(key)
        triplets.append({
            "subject": subject,
            "relation": formatted,
            "object": obj,
            "confidence": 0.9,
            "source": "REBEL",
        })
    print(f"Marker parsing found {len(triplets)} triplets")

    if not triplets:
        print("Trying fallback parsing...")
        # Heuristic "<words> <snake_case_word> <words>" match for outputs that
        # lost their markers.
        fallback_pattern = r'([A-Za-z][A-Za-z\s]+?)\s+([a-z_]+)\s+([A-Za-z][A-Za-z\s]+?)(?:\.|$|\n)'
        for match in re.findall(fallback_pattern, text):
            subject = clean_text(match[0])
            relation = clean_text(match[1])
            obj = clean_text(match[2])
            if validate_triplet(subject, relation, obj):
                triplets.append({
                    "subject": subject,
                    "relation": format_relation(relation),
                    "object": obj,
                    "confidence": 0.7,
                    "source": "REBEL-fallback",
                })

    return triplets


def _split_rebel_triplets(text):
    """Split marker-delimited REBEL text into stripped (head, tail, relation) tuples."""
    for noise in ("<s>", "</s>", "<pad>"):
        text = text.replace(noise, " ")
    # Pad the markers so they split as standalone words even if the decoder
    # glued them to neighbouring text.
    for marker in ("<triplet>", "<subj>", "<obj>"):
        text = text.replace(marker, f" {marker} ")

    results = []
    head = tail = relation = ""
    state = None  # which field the next plain word belongs to
    for token in text.split():
        if token == "<triplet>":
            # A new head starts: flush the previous complete triplet, if any.
            if relation:
                results.append((head, tail, relation))
            head, tail, relation = "", "", ""
            state = "head"
        elif token == "<subj>":
            # Another tail for the same head: flush the previous group.
            if relation:
                results.append((head, tail, relation))
            tail, relation = "", ""
            state = "tail"
        elif token == "<obj>":
            relation = ""
            state = "relation"
        elif state == "head":
            head += " " + token
        elif state == "tail":
            tail += " " + token
        elif state == "relation":
            relation += " " + token
    if head and tail and relation:
        results.append((head, tail, relation))
    return [(h.strip(), t.strip(), r.strip()) for h, t, r in results]
| |
|
def clean_text(text):
    """Normalise a fragment of model output.

    Removes anything that looks like a ``<tag>``, collapses runs of whitespace
    to single spaces, and trims leading/trailing ``.,!?;:`` and spaces.
    Falsy input yields "".
    """
    if not text:
        return ""
    without_tags = re.sub(r'<[^>]+>', '', text)
    collapsed = " ".join(without_tags.split())
    return collapsed.strip('.,!?;: ')
| |
|
def format_relation(relation):
    """Turn a raw relation string into a display label.

    A few known single-word relations (matched case-insensitively after
    trimming) map to fixed labels such as "CEO_of" or "married_to".
    Anything else has "_" and "-" turned into spaces and each word
    capitalised. Falsy input yields "related_to".
    """
    if not relation:
        return "related_to"

    canonical = {
        'ceo': 'CEO_of',
        'founder': 'founded_by',
        'president': 'president_of',
        'member': 'member_of',
        'location': 'located_in',
        'country': 'country_of',
        'spouse': 'married_to',
        'parent': 'parent_of',
        'child': 'child_of',
        'sibling': 'sibling_of',
        'employee': 'works_for',
        'owner': 'owns',
        'creator': 'created_by',
    }.get(relation.lower().strip())
    if canonical is not None:
        return canonical

    words = relation.replace('_', ' ').replace('-', ' ').split()
    return ' '.join(word.capitalize() for word in words)
| |
|
def validate_triplet(subject, relation, object_text):
    """Return True when a (subject, relation, object) triplet looks usable.

    A triplet is rejected when any part is empty, when the subject or object
    is shorter than 2 or longer than 50 characters, when the relation is
    longer than 30 characters, when subject and object are equal ignoring
    case, or when the subject or object contains no ASCII letter.
    """
    if not (subject and relation and object_text):
        return False
    if not (2 <= len(subject) <= 50 and 2 <= len(object_text) <= 50):
        return False
    if len(relation) > 30 or subject.lower() == object_text.lower():
        return False
    return bool(re.search(r'[A-Za-z]', subject)) and bool(re.search(r'[A-Za-z]', object_text))
| |
|
def extract_relations_rebel(text):
    """Extract relation triplets from *text* with the loaded REBEL pipeline.

    The pipeline's default "generated_text" output is decoded with special
    tokens stripped, which removes REBEL's <triplet>/<subj>/<obj> markers.
    So, as the REBEL model card does, this requests the raw token ids and
    decodes them with special tokens kept before parsing.

    Returns:
        A list of triplet dicts from parse_rebel_output. Returns [] when no
        REBEL pipeline is loaded, when the preprocessed text is empty, or when
        generation fails (the error is printed; REBEL is best effort).
    """
    if rebel_pipeline is None:
        return []

    try:
        text = preprocess_text_for_rebel(text)
        if not text:
            return []

        # Only generation kwargs are passed through: text2text-generation has
        # no `return_full_text` option, and generate() rejects unknown kwargs.
        outputs = rebel_pipeline(
            text,
            return_tensors=True,
            return_text=False,
            max_length=512,
            min_length=10,
            num_beams=3,
            do_sample=False,
            early_stopping=True,
        )
        token_ids = outputs[0]["generated_token_ids"]
        # batch_decode keeps special tokens unless told otherwise.
        generated_text = rebel_pipeline.tokenizer.batch_decode([token_ids])[0]

        triplets = parse_rebel_output(generated_text)
        print(f"REBEL extracted {len(triplets)} relations")
        return triplets

    except Exception as e:
        print(f"REBEL extraction error: {e}")
        return []
| |
|
def preprocess_text_for_rebel(text, max_chars=200, max_sentences=3):
    """Normalise whitespace and, for long inputs, keep only the first sentences.

    Sentences are split only at ".", "!" or "?" followed by whitespace. This
    keeps decimals and amounts such as "$6.5 billion" intact, and a trailing
    period does not produce an empty sentence.

    Args:
        text: input text.
        max_chars: texts longer than this (after whitespace normalisation)
            are truncated to their first *max_sentences* sentences.
        max_sentences: number of sentences kept when truncating.

    Returns:
        The normalised (and possibly truncated) text. A truncated result
        always ends with sentence punctuation.
    """
    normalized = re.sub(r'\s+', ' ', text).strip()
    if len(normalized) <= max_chars:
        return normalized

    sentences = re.split(r'(?<=[.!?])\s+', normalized)
    truncated = ' '.join(sentences[:max_sentences]).strip()
    if truncated and not truncated.endswith(('.', '!', '?')):
        truncated += '.'
    return truncated
| |
|
def create_simple_fallback_relations(entities, max_relations=5):
    """Heuristically link neighbouring entities when REBEL yields nothing.

    Each entity is paired with the one that follows it in the list, and the
    relation label comes from their NER types via determine_relation_by_type.
    Pairs whose texts are equal ignoring case (the same name mentioned twice
    in a row) are skipped. validate_triplet rejects such pairs too, and
    create_knowledge_graph drops them anyway.

    Args:
        entities: entity dicts with at least "text" and "label" keys, e.g. as
            returned by extract_entities.
        max_relations: maximum number of relations returned (default 5).

    Returns:
        Up to *max_relations* relation dicts with confidence 0.5 and source
        "type-based"; [] when fewer than two entities are given.
    """
    relations = []
    for current, following in zip(entities, entities[1:]):
        if len(relations) >= max_relations:
            break
        if current["text"].lower() == following["text"].lower():
            continue
        relations.append({
            "subject": current["text"],
            "relation": determine_relation_by_type(current["label"], following["label"]),
            "object": following["text"],
            "confidence": 0.5,
            "source": "type-based",
        })
    return relations
| |
|
def determine_relation_by_type(type1, type2):
    """Guess a relation label from the NER types of subject and object.

    Looks the pair up by subject type, then object type; unknown
    combinations yield "related_to".
    """
    by_subject_type = {
        "PER": {"ORG": "works_for", "LOC": "lives_in", "PER": "knows"},
        "ORG": {"PER": "employs", "LOC": "located_in", "ORG": "partners_with"},
        "LOC": {"LOC": "near"},
        "MISC": {"ORG": "owned_by", "PER": "used_by"},
    }
    return by_subject_type.get(type1, {}).get(type2, "related_to")
| |
|
def extract_relations(text, entities=None):
    """Extract relation triplets from *text*.

    REBEL is tried first when it loaded. If it returns nothing (or is not
    loaded), the function falls back to type-based links between neighbouring
    NER entities.

    Args:
        text: input text.
        entities: optional precomputed result of extract_entities(text).
            When omitted, NER runs lazily, and only if the fallback is needed.

    Returns:
        A list of relation dicts; [] on any error (printed; best effort).
    """
    try:
        if rebel_loaded:
            relations = extract_relations_rebel(text)
            if relations:
                return relations
            print("REBEL didn't return relations, using fallback...")

        if entities is None:
            entities = extract_entities(text)
        print(f"Found {len(entities)} entities")
        return create_simple_fallback_relations(entities)

    except Exception as e:
        print(f"Relation extraction error: {e}")
        return []
| |
|
def create_knowledge_graph(triplets):
    """Render relation triplets as a directed graph image.

    Args:
        triplets: dicts with "subject", "relation" and "object" keys.
            Triplets with an empty endpoint or with subject == object are
            ignored. Several relations between the same (subject, object)
            pair are shown together on that edge's label.

    Returns:
        (image, message): a PIL image of the PNG rendering and a status
        message, or (None, explanation) when there is nothing to draw.
    """
    if not triplets:
        return None, "No relations found. Try entering text with clearer relationships."

    G = nx.DiGraph()
    edge_labels = {}
    for triplet in triplets:
        subject = triplet["subject"]
        obj = triplet["object"]
        relation = triplet["relation"]
        if not subject or not obj or subject == obj:
            continue
        G.add_edge(subject, obj)
        # A DiGraph holds one edge per ordered pair; collect every distinct
        # relation instead of letting later triplets overwrite earlier ones.
        labels = edge_labels.setdefault((subject, obj), [])
        if relation not in labels:
            labels.append(relation)

    if G.number_of_nodes() == 0:
        return None, "No valid graph nodes created."

    fig = plt.figure(figsize=(14, 10))
    try:
        # Small graphs get more spacing and iterations; seed keeps layouts stable.
        if G.number_of_nodes() <= 6:
            pos = nx.spring_layout(G, k=3, iterations=100, seed=42)
        else:
            pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

        nx.draw_networkx_nodes(G, pos,
                               node_color='lightblue',
                               node_size=4000,
                               alpha=0.8,
                               linewidths=2,
                               edgecolors='darkblue')
        nx.draw_networkx_edges(G, pos,
                               edge_color='gray',
                               arrows=True,
                               arrowsize=25,
                               alpha=0.6,
                               width=2,
                               connectionstyle="arc3,rad=0.1")
        nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')
        merged_labels = {edge: ", ".join(str(r) for r in rels) for edge, rels in edge_labels.items()}
        nx.draw_networkx_edge_labels(G, pos, merged_labels, font_size=8, font_color='red', font_weight='bold')

        plt.title("Knowledge Graph (REBEL + Fallback)", size=16, weight='bold')
        plt.axis('off')
        plt.tight_layout()

        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=200, bbox_inches='tight')
    finally:
        # Always release the figure; in a long-running server leaked figures
        # accumulate in pyplot's global state.
        plt.close(fig)

    img_buffer.seek(0)
    img = Image.open(img_buffer)
    return img, f"Graph created with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges."
| |
|
def format_entities_for_display(entities):
    """Convert entity dicts into (text, label) pairs for gr.HighlightedText."""
    display_pairs = []
    for entity in entities:
        display_pairs.append((entity["text"], entity["label"]))
    return display_pairs
| |
|
def process_news_text(text):
    """Gradio handler: run NER and relation extraction, then render the graph.

    Returns:
        A 4-tuple matching the UI outputs: (list of (text, label) entity
        pairs, JSON string for the gr.JSON panel, PIL image or None,
        status message). The JSON string is always a valid JSON document,
        because gr.JSON parses string values as JSON.
    """
    if not text or not text.strip():
        return [], json.dumps({"error": "No text provided"}), None, "Please enter some text to analyze."

    try:
        entities = extract_entities(text)
        entity_display = format_entities_for_display(entities)
        triplets = extract_relations(text)
        graph_img, graph_status = create_knowledge_graph(triplets)

        results = {
            "entities_found": len(entities),
            "relations_found": len(triplets),
            "rebel_model_loaded": rebel_loaded,
            "entities": entities,
            "triplets": triplets,
        }

        status = f"✅ Found {len(entities)} entities, {len(triplets)} relations"
        if rebel_loaded:
            status += " (REBEL enabled)"
        else:
            status += " (REBEL not available, using fallback)"

        return entity_display, json.dumps(results, indent=2), graph_img, status

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return [], json.dumps({"error": str(e)}), None, error_msg
| |
|
| | |
# Sample news snippets offered in the UI's Examples panel.
examples = [
    (
        "AI is reshaping corporate fortunes: while Nvidia, Microsoft, and Google surge with AI-driven gains, Apple and Tesla have lagged, revealing a growing split among the 'Magnificent Seven'—with investors watching to see if laggards can catch up or if the group will fracture entirely."
    ),
    (
        "Elon Musk is steering Tesla toward becoming an AI robotics powerhouse, integrating his startup xAI into Tesla vehicles—he’s also asked shareholders to approve Tesla funding xAI, marking a bold shift away from traditional EV focus toward autonomous driving, humanoid robots, and supercomputing infrastructure."
    ),
    (
        "OpenAI, valued at $300 billion with over 500 million weekly users, is under pressure from rivals like Meta, Google, Amazon, and xAI—despite strong uptake, it’s battling talent poaching, delayed model launches due to safety reviews, and legal disputes with Microsoft over partnership terms and AGI control."
    ),
    (
        "Microsoft’s AI chief Mustafa Suleyman stresses a pragmatic, human-centered AI strategy: his focus is on safe, real-world tools like Copilot and Bing, not speculative AGI; he estimates AGI is at least a decade away, reflecting Microsoft’s measured balance with its OpenAI partnership."
    ),
    (
        "Jony Ive, the legendary designer behind the iPhone, is joining OpenAI after a $6.5 billion acquisition of his hardware startup io; the deal sets OpenAI on course to develop consumer AI devices, signaling a major push beyond software into hardware innovation."
    ),
]
| |
|
| | |
with gr.Blocks(title="REBEL Knowledge Graph Extractor", theme=gr.themes.Soft()) as demo:
    # Header, including whether a REBEL checkpoint was loaded at startup.
    gr.HTML(f"""
<div style="text-align: center; margin: 20px;">
<h1>🤖 News Knowledge Graph Extractor</h1>
<p>Optimized for REBEL model relation extraction</p>
<p><strong>Status:</strong> REBEL Model {'✅ Loaded' if rebel_loaded else '❌ Not Available'}</p>
</div>
""")

    with gr.Row():
        # Left column: text input, trigger button and status line.
        with gr.Column():
            input_text = gr.Textbox(lines=6, placeholder="Enter your text here...", label="Input Text")
            process_btn = gr.Button("Extract Relations", variant="primary")
            status_output = gr.Textbox(max_lines=2, interactive=False, label="Status")

        # Right column: highlighted entities and the detailed JSON results.
        with gr.Column():
            entity_output = gr.HighlightedText(
                label="Named Entities",
                color_map=dict(PER="lightblue", ORG="lightgreen", LOC="orange", MISC="lightpink"),
            )
            results_output = gr.JSON(label="Detailed Results")

    with gr.Row():
        graph_output = gr.Image(label="Knowledge Graph", height=600)

    with gr.Row():
        gr.Examples(examples=examples, inputs=[input_text])

    # The button and submitting the textbox both run the same handler with
    # the same inputs and outputs.
    for trigger in (process_btn.click, input_text.submit):
        trigger(
            fn=process_news_text,
            inputs=[input_text],
            outputs=[entity_output, results_output, graph_output, status_output],
        )

if __name__ == "__main__":
    demo.launch(share=False, server_port=7860)