vietexob committed on
Commit
1378401
·
1 Parent(s): a110d08

Disabled phi3 model

Browse files
Files changed (2) hide show
  1. app.py +29 -19
  2. llm_graph.py +6 -2
app.py CHANGED
@@ -18,6 +18,14 @@ import networkx as nx
18
 
19
  from llm_graph import LLMGraph, MODEL_LIST
20
 
 
 
 
 
 
 
 
 
21
  from pyvis.network import Network
22
  from spacy import displacy
23
  from spacy.tokens import Span
@@ -107,7 +115,7 @@ def extract_kg(text="", model_name=MODEL_LIST[0], model=None, graph_file=""):
107
 
108
  try:
109
  start_time = time.time()
110
- if model_name == MODEL_LIST[1] and os.path.exists(graph_file):
111
  # Load the graph directly from cache
112
  logging.info(f"Loading graph from cache: {graph_file}")
113
  G = nx.read_graphml(graph_file)
@@ -116,7 +124,7 @@ def extract_kg(text="", model_name=MODEL_LIST[0], model=None, graph_file=""):
116
  result = nx.node_link_data(G, edges="edges")
117
  else:
118
  result = model.extract(text, model_name, graph_file)
119
-
120
  end_time = time.time()
121
  duration = end_time - start_time
122
  logging.info(f"Response time: {duration:.4f} seconds")
@@ -212,7 +220,7 @@ def create_graph(json_data, model_name=MODEL_LIST[0], graph_file=""):
212
  Create interactive knowledge graph using Pyvis.
213
  """
214
 
215
- if model_name == MODEL_LIST[0]:
216
  G = nx.Graph()
217
 
218
  # Add nodes with tooltips and error handling for missing keys
@@ -222,7 +230,7 @@ def create_graph(json_data, model_name=MODEL_LIST[0], graph_file=""):
222
 
223
  # Get detailed type with fallback
224
  detailed_type = node.get("detailed_type", type)
225
-
226
  # Use node ID and type info for the tooltip
227
  G.add_node(node['id'], title=f"{type}: {detailed_type}")
228
 
@@ -247,7 +255,7 @@ def create_graph(json_data, model_name=MODEL_LIST[0], graph_file=""):
247
 
248
  # Configure network display
249
  network.from_nx(G)
250
- if model_name == MODEL_LIST[0]:
251
  network.barnes_hut(
252
  gravity=-3000,
253
  central_gravity=0.3,
@@ -339,12 +347,12 @@ def process_and_visualize(text, model_name, progress=gr.Progress()):
339
  is_first_example = text == EXAMPLES[0][0]
340
 
341
  # Try to load from cache if it's the first example
342
- if is_first_example and model_name == MODEL_LIST[0] and os.path.exists(EXAMPLE_CACHE_FILE):
343
  try:
344
  progress(0.3, desc="Loading from cache...")
345
  with open(EXAMPLE_CACHE_FILE, 'rb') as f:
346
  cached_data = pickle.load(f)
347
-
348
  progress(1.0, desc="Loaded from cache!")
349
  return cached_data["graph_html"], cached_data["entities_viz"], cached_data["json_data"], cached_data["stats"]
350
  except Exception as e:
@@ -379,20 +387,20 @@ def process_and_visualize(text, model_name, progress=gr.Progress()):
379
  json_data = extract_kg(text, model_name, model, graph_file)
380
 
381
  progress(0.5, desc="Creating entity visualization...")
382
- if model_name == MODEL_LIST[0]:
383
  entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
384
  else:
385
  entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
386
-
387
  progress(0.8, desc="Building knowledge graph...")
388
  graph_html = create_graph(json_data, model_name, graph_file)
389
-
390
  node_count = len(json_data["nodes"])
391
  edge_count = len(json_data["edges"])
392
  stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
393
-
394
  # Save to cache if it's the first example
395
- if is_first_example and model_name == MODEL_LIST[0]:
396
  try:
397
  cached_data = {
398
  "graph_html": graph_html,
@@ -421,6 +429,7 @@ EXAMPLES = [
421
  def generate_first_example():
422
  """
423
  Generate cache for the first example if it doesn't exist when the app starts.
 
424
  """
425
 
426
  if not os.path.exists(EXAMPLE_CACHE_FILE):
@@ -428,21 +437,22 @@ def generate_first_example():
428
 
429
  try:
430
  text = EXAMPLES[0][0]
431
- model_name = MODEL_LIST[0] if MODEL_LIST else None
 
432
 
433
  # Initialize the LLMGraph model
434
  model = LLMGraph(working_dir=WORKING_DIR)
435
- asyncio.run(model.initialize_rag())
436
 
437
  # Extract data
438
  json_data = extract_kg(text, model_name, model)
439
- entities_viz = create_custom_entity_viz(json_data, text)
440
- graph_html = create_graph(json_data)
441
-
442
  node_count = len(json_data["nodes"])
443
  edge_count = len(json_data["edges"])
444
  stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
445
-
446
  # Save to cache
447
  cached_data = {
448
  "graph_html": graph_html,
@@ -467,7 +477,7 @@ def generate_first_example():
467
  return pickle.load(f)
468
  except Exception as e:
469
  logging.error(f"Error loading existing cache: {str(e)}")
470
-
471
  return None
472
 
473
  def create_ui():
 
18
 
19
  from llm_graph import LLMGraph, MODEL_LIST
20
 
21
+ # Helper function to check if a model is the Phi-3 model (disabled or not)
22
+ def is_phi3_model(model_name):
23
+ return "Phi-3-mini-128k-instruct-graph" in model_name
24
+
25
+ # Helper function to check if a model is the OpenAI model
26
+ def is_openai_model(model_name):
27
+ return "OpenAI" in model_name or "GPT" in model_name
28
+
29
  from pyvis.network import Network
30
  from spacy import displacy
31
  from spacy.tokens import Span
 
115
 
116
  try:
117
  start_time = time.time()
118
+ if is_openai_model(model_name) and os.path.exists(graph_file):
119
  # Load the graph directly from cache
120
  logging.info(f"Loading graph from cache: {graph_file}")
121
  G = nx.read_graphml(graph_file)
 
124
  result = nx.node_link_data(G, edges="edges")
125
  else:
126
  result = model.extract(text, model_name, graph_file)
127
+
128
  end_time = time.time()
129
  duration = end_time - start_time
130
  logging.info(f"Response time: {duration:.4f} seconds")
 
220
  Create interactive knowledge graph using Pyvis.
221
  """
222
 
223
+ if is_phi3_model(model_name):
224
  G = nx.Graph()
225
 
226
  # Add nodes with tooltips and error handling for missing keys
 
230
 
231
  # Get detailed type with fallback
232
  detailed_type = node.get("detailed_type", type)
233
+
234
  # Use node ID and type info for the tooltip
235
  G.add_node(node['id'], title=f"{type}: {detailed_type}")
236
 
 
255
 
256
  # Configure network display
257
  network.from_nx(G)
258
+ if is_phi3_model(model_name):
259
  network.barnes_hut(
260
  gravity=-3000,
261
  central_gravity=0.3,
 
347
  is_first_example = text == EXAMPLES[0][0]
348
 
349
  # Try to load from cache if it's the first example
350
+ if is_first_example and is_phi3_model(model_name) and os.path.exists(EXAMPLE_CACHE_FILE):
351
  try:
352
  progress(0.3, desc="Loading from cache...")
353
  with open(EXAMPLE_CACHE_FILE, 'rb') as f:
354
  cached_data = pickle.load(f)
355
+
356
  progress(1.0, desc="Loaded from cache!")
357
  return cached_data["graph_html"], cached_data["entities_viz"], cached_data["json_data"], cached_data["stats"]
358
  except Exception as e:
 
387
  json_data = extract_kg(text, model_name, model, graph_file)
388
 
389
  progress(0.5, desc="Creating entity visualization...")
390
+ if is_phi3_model(model_name):
391
  entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
392
  else:
393
  entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
394
+
395
  progress(0.8, desc="Building knowledge graph...")
396
  graph_html = create_graph(json_data, model_name, graph_file)
397
+
398
  node_count = len(json_data["nodes"])
399
  edge_count = len(json_data["edges"])
400
  stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
401
+
402
  # Save to cache if it's the first example
403
+ if is_first_example and is_phi3_model(model_name):
404
  try:
405
  cached_data = {
406
  "graph_html": graph_html,
 
429
  def generate_first_example():
430
  """
431
  Generate cache for the first example if it doesn't exist when the app starts.
432
+ Note: Since the first model (Phi-3) is disabled, we use the second model (OpenAI) for caching.
433
  """
434
 
435
  if not os.path.exists(EXAMPLE_CACHE_FILE):
 
437
 
438
  try:
439
  text = EXAMPLES[0][0]
440
+ # Use the second model (OpenAI) since the first is disabled
441
+ model_name = MODEL_LIST[1] if len(MODEL_LIST) > 1 else None
442
 
443
  # Initialize the LLMGraph model
444
  model = LLMGraph(working_dir=WORKING_DIR)
445
+ asyncio.run(model.initialize_rag())
446
 
447
  # Extract data
448
  json_data = extract_kg(text, model_name, model)
449
+ entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
450
+ graph_html = create_graph(json_data, model_name)
451
+
452
  node_count = len(json_data["nodes"])
453
  edge_count = len(json_data["edges"])
454
  stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
455
+
456
  # Save to cache
457
  cached_data = {
458
  "graph_html": graph_html,
 
477
  return pickle.load(f)
478
  except Exception as e:
479
  logging.error(f"Error loading existing cache: {str(e)}")
480
+
481
  return None
482
 
483
  def create_ui():
llm_graph.py CHANGED
@@ -28,7 +28,7 @@ AZURE_EMBEDDING_DEPLOYMENT = os.environ["AZURE_EMBEDDING_DEPLOYMENT"]
28
  AZURE_EMBEDDING_API_VERSION = os.environ["AZURE_EMBEDDING_API_VERSION"]
29
 
30
  MODEL_LIST = [
31
- "EmergentMethods/Phi-3-mini-128k-instruct-graph",
32
  "OpenAI/GPT-4.1-mini",
33
  ]
34
 
@@ -135,7 +135,11 @@ class LLMGraph:
135
  Extract knowledge graph in structured format from text.
136
  """
137
 
138
- if model_name == MODEL_LIST[0]:
 
 
 
 
139
  # Use Hugging Face Inference API with Phi-3-mini-128k-instruct-graph
140
  messages = self._get_messages(text)
141
 
 
28
  AZURE_EMBEDDING_API_VERSION = os.environ["AZURE_EMBEDDING_API_VERSION"]
29
 
30
  MODEL_LIST = [
31
+ "EmergentMethods/Phi-3-mini-128k-instruct-graph (Disabled)",
32
  "OpenAI/GPT-4.1-mini",
33
  ]
34
 
 
135
  Extract knowledge graph in structured format from text.
136
  """
137
 
138
+ # Check if the model is the disabled Phi-3 model
139
+ if "Phi-3-mini-128k-instruct-graph" in model_name and "Disabled" in model_name:
140
+ raise ValueError("This model is currently disabled. Please select another model.")
141
+
142
+ if model_name == MODEL_LIST[0] or "Phi-3-mini-128k-instruct-graph" in model_name:
143
  # Use Hugging Face Inference API with Phi-3-mini-128k-instruct-graph
144
  messages = self._get_messages(text)
145