# ner-kg / app.py
# Hugging Face Space page header (kept for provenance):
#   author avatar: entropy25 | commit message: "Update app.py"
#   commit b46b6c4 (verified)
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import io
from PIL import Image
import json
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import os
import re
from collections import defaultdict
import torch
# Set the HF_HUB_ENABLE_HF_TRANSFER flag to "0" (presumably to keep the optional
# hf_transfer download backend off).
# NOTE(review): this runs after `transformers` has already been imported above;
# confirm the hub library still reads the variable at that point.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# Load the NER model and tokenizer once at import time and wrap them in a
# token-classification pipeline. The pipeline results are consumed by
# extract_entities(), which reads the "entity_group", "word", "start", "end"
# and "score" fields of each returned span.
print("Loading NER model...")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
# The REBEL relation-extraction pipeline is optional: it starts out as None and
# is only replaced if load_rebel_model() manages to load a checkpoint.
print("Loading REBEL model...")
rebel_pipeline = None
def load_rebel_model():
    """Load the first REBEL checkpoint that can be instantiated.

    Each candidate is tried in order. On success the module-level
    ``rebel_pipeline`` is bound to a ``text2text-generation`` pipeline for
    that checkpoint (half precision with automatic device placement when
    CUDA is available, float32 otherwise).

    Returns:
        bool: True if a checkpoint was loaded, False if every candidate failed.
    """
    global rebel_pipeline

    candidate_checkpoints = ("Babelscape/rebel-small", "Babelscape/rebel-base")

    for checkpoint in candidate_checkpoints:
        try:
            print(f"Trying to load {checkpoint}...")
            use_gpu = torch.cuda.is_available()
            rebel_pipeline = pipeline(
                "text2text-generation",
                model=checkpoint,
                tokenizer=checkpoint,
                device_map="auto" if use_gpu else None,
                torch_dtype=torch.float16 if use_gpu else torch.float32,
            )
        except Exception as e:
            # Best effort: report a shortened error and move on to the next
            # candidate instead of aborting start-up.
            print(f"❌ Failed to load {checkpoint}: {str(e)[:100]}...")
        else:
            print(f"✅ Successfully loaded {checkpoint}")
            return True

    print("❌ Could not load any REBEL model")
    return False
# Attempt to load a REBEL checkpoint once at import time. `rebel_loaded` records
# whether that succeeded; it is read later to decide between REBEL extraction
# and the type-based fallback, and to report the model status to the user.
rebel_loaded = load_rebel_model()
def extract_entities(text):
    """Run the module-level NER pipeline on *text* and normalise its output.

    Returns:
        list[dict]: one dict per detected span with the keys ``text``
        (surface form with any ``##`` sub-word markers removed), ``label``
        (the pipeline's entity group), ``start`` / ``end`` (int character
        offsets) and ``confidence`` (score rounded to 3 decimals).
    """
    return [
        {
            "text": span["word"].replace("##", ""),  # drop sub-word markers
            "label": span["entity_group"],
            "start": int(span["start"]),
            "end": int(span["end"]),
            "confidence": round(float(span["score"]), 3),
        }
        for span in ner_pipeline(text)
    ]
def _split_rebel_triplets(generated_text):
    """Split REBEL's linearised output into raw (subject, relation, object) tuples.

    REBEL emits ``<triplet> head <subj> tail <obj> relation`` and may attach
    several ``<subj> tail <obj> relation`` groups to the same head.
    """
    # Sequence boundary / padding tokens carry no information for parsing.
    text = re.sub(r'</?s>|<pad>', ' ', generated_text, flags=re.IGNORECASE)
    # With a capturing group, re.split returns
    # [preamble, tag, content, tag, content, ...].
    parts = re.split(r'(<triplet>|<subj>|<obj>)', text, flags=re.IGNORECASE)

    raw_triplets = []
    # Be lenient when the leading <triplet> marker is missing: treat the text
    # before the first marker as the head entity.
    subject = parts[0].strip()
    tail = ''
    for tag, content in zip(parts[1::2], parts[2::2]):
        tag = tag.lower()
        content = content.strip()
        if tag == '<triplet>':
            subject, tail = content, ''
        elif tag == '<subj>':
            tail = content
        elif subject and tail and content:
            # tag == '<obj>': the content is the relation that links the
            # current head to the current tail.
            raw_triplets.append((subject, content, tail))
    return raw_triplets


def parse_rebel_output(generated_text):
    """Parse REBEL model output into a list of relation triplets.

    The text is expected in REBEL's linearised form, where the head entity
    follows ``<triplet>``, the tail entity follows ``<subj>`` and the relation
    name follows ``<obj>``. If no valid tagged triplet is found, a loose
    "subject relation object" regex is tried as a last resort.

    Args:
        generated_text: decoded model output (may be empty or None).

    Returns:
        list[dict]: dicts with the keys ``subject``, ``relation``, ``object``,
        ``confidence`` (0.9 for tagged triplets, 0.7 for the heuristic
        fallback) and ``source`` ("REBEL" or "REBEL-fallback").
    """
    triplets = []
    print(f"Raw REBEL output: {generated_text}")
    generated_text = (generated_text or "").strip()

    raw_triplets = _split_rebel_triplets(generated_text)
    print(f"Tag parser found {len(raw_triplets)} candidate triplets")

    seen = set()
    for raw_subject, raw_relation, raw_object in raw_triplets:
        subject = clean_text(raw_subject)
        relation = clean_text(raw_relation)
        obj = clean_text(raw_object)
        key = (subject, relation, obj)
        # The model can repeat itself; keep each triplet only once.
        if key in seen or not validate_triplet(subject, relation, obj):
            continue
        seen.add(key)
        triplets.append({
            "subject": subject,
            "relation": format_relation(relation),
            "object": obj,
            "confidence": 0.9,
            "source": "REBEL"
        })

    # Fallback: best-effort heuristic for output without usable markers
    # (e.g. when special tokens were stripped during decoding).
    if not triplets:
        print("Trying fallback parsing...")
        fallback_pattern = r'([A-Za-z][A-Za-z\s]+?)\s+([a-z_]+)\s+([A-Za-z][A-Za-z\s]+?)(?:\.|$|\n)'
        for match in re.findall(fallback_pattern, generated_text):
            subject = clean_text(match[0])
            relation = clean_text(match[1])
            obj = clean_text(match[2])
            if validate_triplet(subject, relation, obj):
                triplets.append({
                    "subject": subject,
                    "relation": format_relation(relation),
                    "object": obj,
                    "confidence": 0.7,
                    "source": "REBEL-fallback"
                })
    return triplets
def clean_text(text):
    """Normalise an extracted text fragment.

    Strips anything that looks like an ``<...>`` tag, collapses runs of
    whitespace to single spaces and trims surrounding ``.,!?;:`` punctuation
    and spaces. Falsy input yields an empty string.
    """
    if not text:
        return ""
    without_tags = re.sub(r'<[^>]+>', '', text)
    single_spaced = ' '.join(without_tags.split())
    return single_spaced.strip('.,!?;: ')
def format_relation(relation):
    """Turn a raw relation name into its display form.

    Falsy input becomes ``"related_to"``. A handful of well-known relation
    names (matched case-insensitively, ignoring surrounding whitespace) map
    to fixed labels; anything else has ``_`` / ``-`` replaced by spaces and
    every word capitalised.
    """
    if not relation:
        return "related_to"

    known_labels = {
        'ceo': 'CEO_of',
        'founder': 'founded_by',
        'president': 'president_of',
        'member': 'member_of',
        'location': 'located_in',
        'country': 'country_of',
        'spouse': 'married_to',
        'parent': 'parent_of',
        'child': 'child_of',
        'sibling': 'sibling_of',
        'employee': 'works_for',
        'owner': 'owns',
        'creator': 'created_by',
    }
    mapped = known_labels.get(relation.lower().strip())
    if mapped is not None:
        return mapped

    # No direct mapping: treat separators as spaces and capitalise each word.
    words = re.sub(r'[_-]', ' ', relation).split()
    return ' '.join(map(str.capitalize, words))
def _has_letter(value):
    """Return True if *value* contains at least one alphabetic character (any script)."""
    return any(ch.isalpha() for ch in value)


def validate_triplet(subject, relation, object_text):
    """Check whether a (subject, relation, object) triplet is plausible.

    A triplet is rejected when any part is empty, when subject or object is
    shorter than 2 characters, when subject and object are the same string
    ignoring case, when subject/object exceed 50 characters or the relation
    exceeds 30, or when subject or object contains no letter at all.

    Letters are detected with ``str.isalpha`` rather than ``[A-Za-z]`` so
    that entity names written in non-Latin scripts are not discarded.

    Returns:
        bool: True if the triplet passes every check.
    """
    if not subject or not relation or not object_text:
        return False
    # Minimum length for the two entities.
    if len(subject) < 2 or len(object_text) < 2:
        return False
    # Self-relations carry no information; casefold gives a proper caseless match.
    if subject.casefold() == object_text.casefold():
        return False
    # Overly long spans are almost certainly parsing noise.
    if len(subject) > 50 or len(object_text) > 50 or len(relation) > 30:
        return False
    # Entities made only of digits/punctuation are rejected.
    return _has_letter(subject) and _has_letter(object_text)
def extract_relations_rebel(text):
    """Extract relation triplets from *text* with the REBEL pipeline.

    The pipeline is asked for token ids instead of text and the ids are
    decoded here with ``skip_special_tokens=False``: the pipeline's own text
    output drops special tokens, which would remove the ``<triplet>``,
    ``<subj>`` and ``<obj>`` markers that parse_rebel_output() looks for.

    Returns:
        list[dict]: parsed triplets, or ``[]`` when no REBEL pipeline is
        loaded or when generation/parsing fails (the error is printed).
    """
    if not rebel_pipeline:
        return []
    try:
        # Trim and normalise the input first.
        text = preprocess_text_for_rebel(text)
        # Only generation options are passed here. `return_full_text` is not a
        # text2text-generation option; it would be forwarded to `generate()`,
        # which rejects unknown keyword arguments.
        outputs = rebel_pipeline(
            text,
            max_length=512,
            min_length=10,
            num_beams=3,
            do_sample=False,
            early_stopping=True,
            return_tensors=True,
            return_text=False
        )
        token_ids = outputs[0]["generated_token_ids"]
        generated_text = rebel_pipeline.tokenizer.decode(
            token_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=True
        )
        # Parse the output
        triplets = parse_rebel_output(generated_text)
        print(f"REBEL extracted {len(triplets)} relations")
        return triplets
    except Exception as e:
        # Deliberate best effort: the caller falls back to type-based
        # relations when this returns an empty list.
        print(f"REBEL extraction error: {e}")
        return []
def preprocess_text_for_rebel(text, max_chars=200, max_sentences=3):
    """Normalise whitespace and trim long input to its first few sentences.

    Sentences are split only where ``.``, ``!`` or ``?`` is followed by
    whitespace, and the punctuation is kept. This avoids breaking decimals
    such as ``$6.5`` apart and avoids rewriting ``!`` / ``?`` as ``.``.
    (Abbreviations like "U.S. " can still end a "sentence" early, which only
    affects where the text is cut.)

    Args:
        text: input text; falsy input yields ``""``.
        max_chars: texts no longer than this are returned untrimmed.
        max_sentences: number of leading sentences kept for longer texts.

    Returns:
        str: the cleaned (and possibly shortened) text.
    """
    if not text:
        return ""
    # Collapse all whitespace runs first so the length check and the
    # sentence split both work on normalised text.
    text = re.sub(r'\s+', ' ', text).strip()
    if len(text) <= max_chars:
        return text
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if s]
    return ' '.join(sentences[:max_sentences])
def create_simple_fallback_relations(entities):
    """Build heuristic relations between neighbouring entities.

    Each entity is linked to the one that follows it in the list, with the
    relation name chosen from the pair of entity labels. At most the first
    five links are returned; fewer than two entities yield an empty list.
    """
    neighbour_links = [
        {
            "subject": left["text"],
            "relation": determine_relation_by_type(left["label"], right["label"]),
            "object": right["text"],
            "confidence": 0.5,
            "source": "type-based",
        }
        for left, right in zip(entities, entities[1:])
    ]
    # Cap the output at 5 relations.
    return neighbour_links[:5]
def determine_relation_by_type(type1, type2):
    """Pick a relation name for an ordered pair of entity labels.

    Args:
        type1: label of the subject entity (e.g. "PER", "ORG").
        type2: label of the object entity.

    Returns:
        str: the relation for a known label pair, otherwise ``"related_to"``.
    """
    # Outer key: subject label; inner key: object label.
    relations_by_subject = {
        "PER": {"ORG": "works_for", "LOC": "lives_in", "PER": "knows"},
        "ORG": {"PER": "employs", "LOC": "located_in", "ORG": "partners_with"},
        "LOC": {"LOC": "near"},
        "MISC": {"ORG": "owned_by", "PER": "used_by"},
    }
    return relations_by_subject.get(type1, {}).get(type2, "related_to")
def extract_relations(text):
    """Extract relation triplets from *text*.

    Entities are always extracted first. When a REBEL model was loaded, its
    relations are returned if it produced any; otherwise relations are
    derived from the entity types of neighbouring entities. Any exception is
    printed and results in an empty list.
    """
    try:
        found_entities = extract_entities(text)
        print(f"Found {len(found_entities)} entities")

        if rebel_loaded:
            # Prefer model-based relations whenever REBEL yields something.
            rebel_relations = extract_relations_rebel(text)
            if rebel_relations:
                return rebel_relations
            print("REBEL didn't return relations, using fallback...")

        # Fallback to simple type-based relations.
        return create_simple_fallback_relations(found_entities)
    except Exception as e:
        print(f"Relation extraction error: {e}")
        return []
def create_knowledge_graph(triplets):
    """Render relation triplets as a directed-graph image.

    Args:
        triplets: iterable of dicts with ``subject``, ``relation`` and
            ``object`` keys. Entries with an empty or identical
            subject/object are skipped.

    Returns:
        tuple: ``(PIL image, status message)`` on success, or
        ``(None, message)`` when there is nothing to draw.
    """
    if not triplets:
        return None, "No relations found. Try entering text with clearer relationships."

    G = nx.DiGraph()
    # A DiGraph keeps one edge per (subject, object) pair, so collect every
    # distinct relation for that pair instead of letting the last one win.
    relations_by_edge = {}
    for triplet in triplets:
        subject = triplet["subject"]
        obj = triplet["object"]
        relation = str(triplet["relation"])
        if subject and obj and subject != obj:
            G.add_edge(subject, obj)
            labels = relations_by_edge.setdefault((subject, obj), [])
            if relation not in labels:
                labels.append(relation)
    edge_labels = {edge: ", ".join(labels) for edge, labels in relations_by_edge.items()}

    node_count = G.number_of_nodes()
    if node_count == 0:
        return None, "No valid graph nodes created."

    node_size = 4000
    # Draw on an explicit figure/axes rather than pyplot's global "current
    # figure", and always close the figure so a drawing error cannot leak it.
    fig, ax = plt.subplots(figsize=(14, 10))
    try:
        # Small graphs get a looser layout with more iterations.
        if node_count <= 6:
            pos = nx.spring_layout(G, k=3, iterations=100, seed=42)
        else:
            pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

        nx.draw_networkx_nodes(G, pos, ax=ax,
                               node_color='lightblue',
                               node_size=node_size,
                               alpha=0.8,
                               linewidths=2,
                               edgecolors='darkblue')
        # node_size is passed here too so arrowheads stop at the node border
        # instead of being hidden underneath the large node markers.
        nx.draw_networkx_edges(G, pos, ax=ax,
                               edge_color='gray',
                               arrows=True,
                               arrowsize=25,
                               alpha=0.6,
                               width=2,
                               node_size=node_size,
                               connectionstyle="arc3,rad=0.1")
        nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_weight='bold')
        nx.draw_networkx_edge_labels(G, pos, edge_labels, ax=ax,
                                     font_size=8, font_color='red', font_weight='bold')

        ax.set_title("Knowledge Graph (REBEL + Fallback)", size=16, weight='bold')
        ax.axis('off')
        fig.tight_layout()

        # Save to an in-memory buffer.
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=200, bbox_inches='tight')
    finally:
        plt.close(fig)

    img_buffer.seek(0)
    img = Image.open(img_buffer)
    img.load()  # decode now so the image no longer depends on lazy buffer reads
    return img, f"Graph created with {node_count} nodes and {G.number_of_edges()} edges."
def format_entities_for_display(entities):
    """Reduce entity dicts to ``(text, label)`` tuples for the highlighted-text widget."""
    display_pairs = []
    for item in entities:
        display_pairs.append((item["text"], item["label"]))
    return display_pairs
def process_news_text(text):
    """Run the full pipeline for one input text (the Gradio event handler).

    Returns:
        tuple: ``(entity_display, results_json, graph_image, status)`` where
        ``entity_display`` is a list of ``(text, label)`` tuples,
        ``results_json`` is always a valid JSON string (it is shown in a JSON
        component, which parses string values), ``graph_image`` is a PIL
        image or None, and ``status`` is a short human-readable message.
    """
    # Guard against None as well as blank input, and return a JSON document
    # rather than a bare sentence for the JSON output.
    if not text or not text.strip():
        no_input = json.dumps({"message": "No text provided"})
        return [], no_input, None, "Please enter some text to analyze."
    try:
        entities = extract_entities(text)
        entity_display = format_entities_for_display(entities)
        triplets = extract_relations(text)
        graph_img, graph_status = create_knowledge_graph(triplets)
        results = {
            "entities_found": len(entities),
            "relations_found": len(triplets),
            "rebel_model_loaded": rebel_loaded,
            "entities": entities,
            "triplets": triplets
        }
        status = f"✅ Found {len(entities)} entities, {len(triplets)} relations"
        if rebel_loaded:
            status += " (REBEL enabled)"
        else:
            status += " (REBEL not available, using fallback)"
        return entity_display, json.dumps(results, indent=2), graph_img, status
    except Exception as e:
        # Top-level UI boundary: surface the error in the status box instead
        # of crashing the event handler.
        error_msg = f"❌ Error: {str(e)}"
        return [], "{}", None, error_msg
# Example news snippets. The list is handed to gr.Examples further down so a
# user can fill the input textbox with one of these texts.
examples = [
"AI is reshaping corporate fortunes: while Nvidia, Microsoft, and Google surge with AI-driven gains, Apple and Tesla have lagged, revealing a growing split among the 'Magnificent Seven'—with investors watching to see if laggards can catch up or if the group will fracture entirely.",
"Elon Musk is steering Tesla toward becoming an AI robotics powerhouse, integrating his startup xAI into Tesla vehicles—he’s also asked shareholders to approve Tesla funding xAI, marking a bold shift away from traditional EV focus toward autonomous driving, humanoid robots, and supercomputing infrastructure.",
"OpenAI, valued at $300 billion with over 500 million weekly users, is under pressure from rivals like Meta, Google, Amazon, and xAI—despite strong uptake, it’s battling talent poaching, delayed model launches due to safety reviews, and legal disputes with Microsoft over partnership terms and AGI control.",
"Microsoft’s AI chief Mustafa Suleyman stresses a pragmatic, human-centered AI strategy: his focus is on safe, real-world tools like Copilot and Bing, not speculative AGI; he estimates AGI is at least a decade away, reflecting Microsoft’s measured balance with its OpenAI partnership.",
"Jony Ive, the legendary designer behind the iPhone, is joining OpenAI after a $6.5 billion acquisition of his hardware startup io; the deal sets OpenAI on course to develop consumer AI devices, signaling a major push beyond software into hardware innovation."
]
# Build the Gradio UI: input and status on the left, entities and JSON results
# on the right, then the graph image and the example texts underneath.
rebel_status_text = '✅ Loaded' if rebel_loaded else '❌ Not Available'
header_html = f"""
<div style="text-align: center; margin: 20px;">
<h1>🤖 News Knowledge Graph Extractor</h1>
<p>Optimized for REBEL model relation extraction</p>
<p><strong>Status:</strong> REBEL Model {rebel_status_text}</p>
</div>
"""

with gr.Blocks(title="REBEL Knowledge Graph Extractor", theme=gr.themes.Soft()) as demo:
    gr.HTML(header_html)

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here...",
                lines=6,
            )
            process_btn = gr.Button("Extract Relations", variant="primary")
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                max_lines=2,
            )
        with gr.Column():
            entity_colors = {
                "PER": "lightblue",
                "ORG": "lightgreen",
                "LOC": "orange",
                "MISC": "lightpink",
            }
            entity_output = gr.HighlightedText(
                label="Named Entities",
                color_map=entity_colors,
            )
            results_output = gr.JSON(label="Detailed Results")

    with gr.Row():
        graph_output = gr.Image(label="Knowledge Graph", height=600)

    with gr.Row():
        gr.Examples(examples=examples, inputs=[input_text])

    # Both the button click and pressing Enter in the textbox run the same
    # handler with the same inputs and outputs.
    for register_event in (process_btn.click, input_text.submit):
        register_event(
            fn=process_news_text,
            inputs=[input_text],
            outputs=[entity_output, results_output, graph_output, status_output],
        )

if __name__ == "__main__":
    demo.launch(server_port=7860, share=False)