Spaces:
Running on Zero
Running on Zero
Disabled phi3 model
Browse files- app.py +29 -19
- llm_graph.py +6 -2
app.py
CHANGED
|
@@ -18,6 +18,14 @@ import networkx as nx
|
|
| 18 |
|
| 19 |
from llm_graph import LLMGraph, MODEL_LIST
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
from pyvis.network import Network
|
| 22 |
from spacy import displacy
|
| 23 |
from spacy.tokens import Span
|
|
@@ -107,7 +115,7 @@ def extract_kg(text="", model_name=MODEL_LIST[0], model=None, graph_file=""):
|
|
| 107 |
|
| 108 |
try:
|
| 109 |
start_time = time.time()
|
| 110 |
-
if model_name
|
| 111 |
# Load the graph directly from cache
|
| 112 |
logging.info(f"Loading graph from cache: {graph_file}")
|
| 113 |
G = nx.read_graphml(graph_file)
|
|
@@ -116,7 +124,7 @@ def extract_kg(text="", model_name=MODEL_LIST[0], model=None, graph_file=""):
|
|
| 116 |
result = nx.node_link_data(G, edges="edges")
|
| 117 |
else:
|
| 118 |
result = model.extract(text, model_name, graph_file)
|
| 119 |
-
|
| 120 |
end_time = time.time()
|
| 121 |
duration = end_time - start_time
|
| 122 |
logging.info(f"Response time: {duration:.4f} seconds")
|
|
@@ -212,7 +220,7 @@ def create_graph(json_data, model_name=MODEL_LIST[0], graph_file=""):
|
|
| 212 |
Create interactive knowledge graph using Pyvis.
|
| 213 |
"""
|
| 214 |
|
| 215 |
-
if model_name
|
| 216 |
G = nx.Graph()
|
| 217 |
|
| 218 |
# Add nodes with tooltips and error handling for missing keys
|
|
@@ -222,7 +230,7 @@ def create_graph(json_data, model_name=MODEL_LIST[0], graph_file=""):
|
|
| 222 |
|
| 223 |
# Get detailed type with fallback
|
| 224 |
detailed_type = node.get("detailed_type", type)
|
| 225 |
-
|
| 226 |
# Use node ID and type info for the tooltip
|
| 227 |
G.add_node(node['id'], title=f"{type}: {detailed_type}")
|
| 228 |
|
|
@@ -247,7 +255,7 @@ def create_graph(json_data, model_name=MODEL_LIST[0], graph_file=""):
|
|
| 247 |
|
| 248 |
# Configure network display
|
| 249 |
network.from_nx(G)
|
| 250 |
-
if model_name
|
| 251 |
network.barnes_hut(
|
| 252 |
gravity=-3000,
|
| 253 |
central_gravity=0.3,
|
|
@@ -339,12 +347,12 @@ def process_and_visualize(text, model_name, progress=gr.Progress()):
|
|
| 339 |
is_first_example = text == EXAMPLES[0][0]
|
| 340 |
|
| 341 |
# Try to load from cache if it's the first example
|
| 342 |
-
if is_first_example and model_name
|
| 343 |
try:
|
| 344 |
progress(0.3, desc="Loading from cache...")
|
| 345 |
with open(EXAMPLE_CACHE_FILE, 'rb') as f:
|
| 346 |
cached_data = pickle.load(f)
|
| 347 |
-
|
| 348 |
progress(1.0, desc="Loaded from cache!")
|
| 349 |
return cached_data["graph_html"], cached_data["entities_viz"], cached_data["json_data"], cached_data["stats"]
|
| 350 |
except Exception as e:
|
|
@@ -379,20 +387,20 @@ def process_and_visualize(text, model_name, progress=gr.Progress()):
|
|
| 379 |
json_data = extract_kg(text, model_name, model, graph_file)
|
| 380 |
|
| 381 |
progress(0.5, desc="Creating entity visualization...")
|
| 382 |
-
if model_name
|
| 383 |
entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
|
| 384 |
else:
|
| 385 |
entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
|
| 386 |
-
|
| 387 |
progress(0.8, desc="Building knowledge graph...")
|
| 388 |
graph_html = create_graph(json_data, model_name, graph_file)
|
| 389 |
-
|
| 390 |
node_count = len(json_data["nodes"])
|
| 391 |
edge_count = len(json_data["edges"])
|
| 392 |
stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
|
| 393 |
-
|
| 394 |
# Save to cache if it's the first example
|
| 395 |
-
if is_first_example and model_name
|
| 396 |
try:
|
| 397 |
cached_data = {
|
| 398 |
"graph_html": graph_html,
|
|
@@ -421,6 +429,7 @@ EXAMPLES = [
|
|
| 421 |
def generate_first_example():
|
| 422 |
"""
|
| 423 |
Generate cache for the first example if it doesn't exist when the app starts.
|
|
|
|
| 424 |
"""
|
| 425 |
|
| 426 |
if not os.path.exists(EXAMPLE_CACHE_FILE):
|
|
@@ -428,21 +437,22 @@ def generate_first_example():
|
|
| 428 |
|
| 429 |
try:
|
| 430 |
text = EXAMPLES[0][0]
|
| 431 |
-
|
|
|
|
| 432 |
|
| 433 |
# Initialize the LLMGraph model
|
| 434 |
model = LLMGraph(working_dir=WORKING_DIR)
|
| 435 |
-
asyncio.run(model.initialize_rag())
|
| 436 |
|
| 437 |
# Extract data
|
| 438 |
json_data = extract_kg(text, model_name, model)
|
| 439 |
-
entities_viz = create_custom_entity_viz(json_data, text)
|
| 440 |
-
graph_html = create_graph(json_data)
|
| 441 |
-
|
| 442 |
node_count = len(json_data["nodes"])
|
| 443 |
edge_count = len(json_data["edges"])
|
| 444 |
stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
|
| 445 |
-
|
| 446 |
# Save to cache
|
| 447 |
cached_data = {
|
| 448 |
"graph_html": graph_html,
|
|
@@ -467,7 +477,7 @@ def generate_first_example():
|
|
| 467 |
return pickle.load(f)
|
| 468 |
except Exception as e:
|
| 469 |
logging.error(f"Error loading existing cache: {str(e)}")
|
| 470 |
-
|
| 471 |
return None
|
| 472 |
|
| 473 |
def create_ui():
|
|
|
|
| 18 |
|
| 19 |
from llm_graph import LLMGraph, MODEL_LIST
|
| 20 |
|
| 21 |
+
# Helper function to check if a model is the Phi-3 model (disabled or not)
|
| 22 |
+
def is_phi3_model(model_name):
|
| 23 |
+
return "Phi-3-mini-128k-instruct-graph" in model_name
|
| 24 |
+
|
| 25 |
+
# Helper function to check if a model is the OpenAI model
|
| 26 |
+
def is_openai_model(model_name):
|
| 27 |
+
return "OpenAI" in model_name or "GPT" in model_name
|
| 28 |
+
|
| 29 |
from pyvis.network import Network
|
| 30 |
from spacy import displacy
|
| 31 |
from spacy.tokens import Span
|
|
|
|
| 115 |
|
| 116 |
try:
|
| 117 |
start_time = time.time()
|
| 118 |
+
if is_openai_model(model_name) and os.path.exists(graph_file):
|
| 119 |
# Load the graph directly from cache
|
| 120 |
logging.info(f"Loading graph from cache: {graph_file}")
|
| 121 |
G = nx.read_graphml(graph_file)
|
|
|
|
| 124 |
result = nx.node_link_data(G, edges="edges")
|
| 125 |
else:
|
| 126 |
result = model.extract(text, model_name, graph_file)
|
| 127 |
+
|
| 128 |
end_time = time.time()
|
| 129 |
duration = end_time - start_time
|
| 130 |
logging.info(f"Response time: {duration:.4f} seconds")
|
|
|
|
| 220 |
Create interactive knowledge graph using Pyvis.
|
| 221 |
"""
|
| 222 |
|
| 223 |
+
if is_phi3_model(model_name):
|
| 224 |
G = nx.Graph()
|
| 225 |
|
| 226 |
# Add nodes with tooltips and error handling for missing keys
|
|
|
|
| 230 |
|
| 231 |
# Get detailed type with fallback
|
| 232 |
detailed_type = node.get("detailed_type", type)
|
| 233 |
+
|
| 234 |
# Use node ID and type info for the tooltip
|
| 235 |
G.add_node(node['id'], title=f"{type}: {detailed_type}")
|
| 236 |
|
|
|
|
| 255 |
|
| 256 |
# Configure network display
|
| 257 |
network.from_nx(G)
|
| 258 |
+
if is_phi3_model(model_name):
|
| 259 |
network.barnes_hut(
|
| 260 |
gravity=-3000,
|
| 261 |
central_gravity=0.3,
|
|
|
|
| 347 |
is_first_example = text == EXAMPLES[0][0]
|
| 348 |
|
| 349 |
# Try to load from cache if it's the first example
|
| 350 |
+
if is_first_example and is_phi3_model(model_name) and os.path.exists(EXAMPLE_CACHE_FILE):
|
| 351 |
try:
|
| 352 |
progress(0.3, desc="Loading from cache...")
|
| 353 |
with open(EXAMPLE_CACHE_FILE, 'rb') as f:
|
| 354 |
cached_data = pickle.load(f)
|
| 355 |
+
|
| 356 |
progress(1.0, desc="Loaded from cache!")
|
| 357 |
return cached_data["graph_html"], cached_data["entities_viz"], cached_data["json_data"], cached_data["stats"]
|
| 358 |
except Exception as e:
|
|
|
|
| 387 |
json_data = extract_kg(text, model_name, model, graph_file)
|
| 388 |
|
| 389 |
progress(0.5, desc="Creating entity visualization...")
|
| 390 |
+
if is_phi3_model(model_name):
|
| 391 |
entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
|
| 392 |
else:
|
| 393 |
entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
|
| 394 |
+
|
| 395 |
progress(0.8, desc="Building knowledge graph...")
|
| 396 |
graph_html = create_graph(json_data, model_name, graph_file)
|
| 397 |
+
|
| 398 |
node_count = len(json_data["nodes"])
|
| 399 |
edge_count = len(json_data["edges"])
|
| 400 |
stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
|
| 401 |
+
|
| 402 |
# Save to cache if it's the first example
|
| 403 |
+
if is_first_example and is_phi3_model(model_name):
|
| 404 |
try:
|
| 405 |
cached_data = {
|
| 406 |
"graph_html": graph_html,
|
|
|
|
| 429 |
def generate_first_example():
|
| 430 |
"""
|
| 431 |
Generate cache for the first example if it doesn't exist when the app starts.
|
| 432 |
+
Note: Since the first model (Phi-3) is disabled, we use the second model (OpenAI) for caching.
|
| 433 |
"""
|
| 434 |
|
| 435 |
if not os.path.exists(EXAMPLE_CACHE_FILE):
|
|
|
|
| 437 |
|
| 438 |
try:
|
| 439 |
text = EXAMPLES[0][0]
|
| 440 |
+
# Use the second model (OpenAI) since the first is disabled
|
| 441 |
+
model_name = MODEL_LIST[1] if len(MODEL_LIST) > 1 else None
|
| 442 |
|
| 443 |
# Initialize the LLMGraph model
|
| 444 |
model = LLMGraph(working_dir=WORKING_DIR)
|
| 445 |
+
asyncio.run(model.initialize_rag())
|
| 446 |
|
| 447 |
# Extract data
|
| 448 |
json_data = extract_kg(text, model_name, model)
|
| 449 |
+
entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
|
| 450 |
+
graph_html = create_graph(json_data, model_name)
|
| 451 |
+
|
| 452 |
node_count = len(json_data["nodes"])
|
| 453 |
edge_count = len(json_data["edges"])
|
| 454 |
stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
|
| 455 |
+
|
| 456 |
# Save to cache
|
| 457 |
cached_data = {
|
| 458 |
"graph_html": graph_html,
|
|
|
|
| 477 |
return pickle.load(f)
|
| 478 |
except Exception as e:
|
| 479 |
logging.error(f"Error loading existing cache: {str(e)}")
|
| 480 |
+
|
| 481 |
return None
|
| 482 |
|
| 483 |
def create_ui():
|
llm_graph.py
CHANGED
|
@@ -28,7 +28,7 @@ AZURE_EMBEDDING_DEPLOYMENT = os.environ["AZURE_EMBEDDING_DEPLOYMENT"]
|
|
| 28 |
AZURE_EMBEDDING_API_VERSION = os.environ["AZURE_EMBEDDING_API_VERSION"]
|
| 29 |
|
| 30 |
MODEL_LIST = [
|
| 31 |
-
"EmergentMethods/Phi-3-mini-128k-instruct-graph",
|
| 32 |
"OpenAI/GPT-4.1-mini",
|
| 33 |
]
|
| 34 |
|
|
@@ -135,7 +135,11 @@ class LLMGraph:
|
|
| 135 |
Extract knowledge graph in structured format from text.
|
| 136 |
"""
|
| 137 |
|
| 138 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
# Use Hugging Face Inference API with Phi-3-mini-128k-instruct-graph
|
| 140 |
messages = self._get_messages(text)
|
| 141 |
|
|
|
|
| 28 |
AZURE_EMBEDDING_API_VERSION = os.environ["AZURE_EMBEDDING_API_VERSION"]
|
| 29 |
|
| 30 |
MODEL_LIST = [
|
| 31 |
+
"EmergentMethods/Phi-3-mini-128k-instruct-graph (Disabled)",
|
| 32 |
"OpenAI/GPT-4.1-mini",
|
| 33 |
]
|
| 34 |
|
|
|
|
| 135 |
Extract knowledge graph in structured format from text.
|
| 136 |
"""
|
| 137 |
|
| 138 |
+
# Check if the model is the disabled Phi-3 model
|
| 139 |
+
if "Phi-3-mini-128k-instruct-graph" in model_name and "Disabled" in model_name:
|
| 140 |
+
raise ValueError("This model is currently disabled. Please select another model.")
|
| 141 |
+
|
| 142 |
+
if model_name == MODEL_LIST[0] or "Phi-3-mini-128k-instruct-graph" in model_name:
|
| 143 |
# Use Hugging Face Inference API with Phi-3-mini-128k-instruct-graph
|
| 144 |
messages = self._get_messages(text)
|
| 145 |
|