AIEcosystem committed on
Commit
0a622e5
·
verified ·
1 Parent(s): 5707233

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +99 -37
src/streamlit_app.py CHANGED
@@ -31,6 +31,7 @@ except ImportError:
31
 
32
  # --- Fixed Label Definitions and Mappings ---
33
  FIXED_LABELS = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
 
34
  FIXED_ENTITY_COLOR_MAP = {
35
  "person": "#10b981", # Green
36
  "country": "#3b82f6", # Blue
@@ -51,11 +52,8 @@ FIXED_CATEGORY_MAPPING = {
51
  "Numbers & Finance": ["money", "cardinal"]}
52
  REVERSE_FIXED_CATEGORY_MAPPING = {label: category for category, label_list in FIXED_CATEGORY_MAPPING.items() for label in label_list}
53
 
54
- # --- Default Custom Labels (Not used, but kept for full code compatibility) ---
55
- DEFAULT_CUSTOM_LABELS = "person, location, organization, product, date, time, event"
56
-
57
  # --- Dynamic Color Generator for Custom Labels ---
58
- COLOR_PALETTE = cycle(px.colors.qualitative.Alphabet)
59
 
60
  def extract_label(node_name):
61
  """Extracts the label from a node string like 'Text (Label)'."""
@@ -69,9 +67,16 @@ def remove_trailing_punctuation(text_string):
69
  def get_dynamic_color_map(active_labels, fixed_map):
70
  """Generates a color map, using fixed colors if available, otherwise dynamic colors."""
71
  color_map = {}
72
- if active_labels == FIXED_LABELS:
 
 
73
  return fixed_map
74
 
 
 
 
 
 
75
  for label in active_labels:
76
  if label in fixed_map:
77
  color_map[label] = fixed_map[label]
@@ -84,8 +89,7 @@ def highlight_entities(text, df_entities, entity_color_map):
84
  if df_entities.empty:
85
  return text
86
 
87
- # --- FIX: Ensure the DataFrame has a unique index before sorting/converting ---
88
- # Create a copy and reset index for safety, resolving potential errors in the to_dict step.
89
  df_entities = df_entities.copy().reset_index(drop=True)
90
 
91
  entities = df_entities.sort_values(by='start', ascending=False).to_dict('records')
@@ -537,11 +541,11 @@ with tab2:
537
  st.markdown("For any errors or inquiries, please contact us at [info@your-company.com](mailto:info@your-company.com)")
538
 
539
  # --- Model Loading ---
540
- @st.cache_resource
541
  def load_ner_model(labels):
542
  """Loads the GLiNER model and caches it."""
543
  try:
544
- # Note: Model loading is kept for fixed labels although not strictly required by the prompt
545
  return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints=labels)
546
  except Exception as e:
547
  # print(f"FATAL ERROR: Failed to load NER model: {e}")
@@ -566,7 +570,7 @@ DEFAULT_TEXT = (
566
  "general public by October 1st. The goal is to deploy the **Astra** v2 platform before the next solar eclipse event in 2026.")
567
 
568
  # -----------------------------------
569
- # --- Session State Initialization (Cleaned up for Text Area Input) ---
570
  if 'show_results' not in st.session_state: st.session_state.show_results = False
571
  if 'my_text_area' not in st.session_state: st.session_state.my_text_area = DEFAULT_TEXT
572
  if 'last_text' not in st.session_state: st.session_state.last_text = ""
@@ -574,7 +578,8 @@ if 'results_df' not in st.session_state: st.session_state.results_df = pd.DataFr
574
  if 'elapsed_time' not in st.session_state: st.session_state.elapsed_time = 0.0
575
  if 'topic_results' not in st.session_state: st.session_state.topic_results = None
576
  if 'active_labels_list' not in st.session_state: st.session_state.active_labels_list = FIXED_LABELS
577
- if 'is_custom_mode' not in st.session_state: st.session_state.is_custom_mode = False # Force Fixed Mode
 
578
  if 'num_topics_slider' not in st.session_state: st.session_state.num_topics_slider = 5
579
  if 'num_top_words_slider' not in st.session_state: st.session_state.num_top_words_slider = 10
580
  if 'last_num_topics' not in st.session_state: st.session_state.last_num_topics = None
@@ -593,7 +598,7 @@ def clear_text():
593
  # --- Revised Text Area Input ---
594
  st.markdown("## ✍️ Text Input for Analysis")
595
 
596
- word_limit = 10000
597
  text = st.text_area(
598
  f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter",
599
  height=250,
@@ -603,43 +608,67 @@ text = st.text_area(
603
  word_count = len(text.split())
604
  st.markdown(f"**Word count:** {word_count}/{word_limit}")
605
 
606
- # Always Fixed Mode controls
607
  st.markdown("---")
608
- st.markdown("### Analysis Mode: **Fixed Entity Labels**")
609
- st.info(f"The analysis will use the pre-defined fixed label set: **{', '.join(FIXED_LABELS)}**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
 
 
611
  col_results, col_clear = st.columns([1, 1])
612
 
613
  with col_results:
614
- run_button = st.button("Analyze Text", key='run_results', use_container_width=True, type="primary")
615
 
616
  with col_clear:
617
  st.button("Clear text", on_click=clear_text, use_container_width=True)
618
 
619
- # --- Define Active Labels and Settings ---
620
- active_labels = FIXED_LABELS # Always fixed labels
621
- st.session_state.active_labels_list = active_labels
622
- current_num_topics = st.session_state.num_topics_slider
623
- current_num_top_words = st.session_state.num_top_words_slider
624
-
625
  # --- Results Trigger and Processing (Fixed for index error) ---
626
  if run_button:
627
  if text.strip() and word_count <= word_limit:
628
 
629
- # 1. Determine Active Labels and Mode (Already determined: Fixed Mode)
630
  active_labels = st.session_state.active_labels_list
631
 
632
  # Caching Logic: Check if we need to re-run the full process
633
  should_rerun_full_analysis = (
634
  text.strip() != st.session_state.last_text.strip() or
635
- active_labels != st.session_state.last_active_labels
636
  )
637
 
638
  if should_rerun_full_analysis:
639
  # 2. Rerunning Full Analysis
640
  CHUNKING_THRESHOLD = 500
641
  should_chunk = word_count > CHUNKING_THRESHOLD
642
- mode_msg = "fixed labels"
643
  if should_chunk:
644
  mode_msg += " with **chunking** for large text"
645
 
@@ -662,11 +691,20 @@ if run_button:
662
  df = pd.DataFrame(all_entities)
663
 
664
  if not df.empty:
665
- # 💥 FIX: Reset to a default, unique integer index
666
  df = df.reset_index(drop=True)
667
 
668
- # Force fixed category mapping
669
- df['category'] = df['label'].map(REVERSE_FIXED_CATEGORY_MAPPING).fillna('Other')
 
 
 
 
 
 
 
 
 
 
670
 
671
  df['text'] = df['text'].apply(remove_trailing_punctuation)
672
 
@@ -691,6 +729,9 @@ if run_button:
691
  elif word_count > word_limit:
692
  st.error(f"Text too long! Please limit your input to {word_limit} words.")
693
  st.session_state.show_results = False
 
 
 
694
  else:
695
  st.warning("Please enter some text to analyze.")
696
  st.session_state.show_results = False
@@ -719,8 +760,7 @@ if st.session_state.show_results:
719
  )
720
 
721
  # 1. Highlighted Text placed inside an Expander
722
- # Force Fixed Mode in display header
723
- with st.expander(f"### 1. Analyzed Text with Highlighted Entities (Fixed Mode)", expanded=False):
724
  st.markdown(
725
  highlight_entities(st.session_state.last_text, df, entity_color_map),
726
  unsafe_allow_html=True
@@ -735,10 +775,32 @@ if st.session_state.show_results:
735
  with tab_category_details:
736
  st.markdown("#### Detailed Entities Table (Grouped by Category)")
737
 
738
- # This section now ONLY executes the FIXED MODE logic
739
- unique_categories = list(FIXED_CATEGORY_MAPPING.keys())
740
- tabs_category = st.tabs(unique_categories)
741
- for category, tab in zip(unique_categories, tabs_category):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742
  df_category = df[df['category'] == category][['text', 'label', 'score', 'start', 'end']].sort_values(by='score', ascending=False)
743
  styled_df_category = color_score_gradient(df_category)
744
  with tab:
@@ -801,7 +863,6 @@ if st.session_state.show_results:
801
 
802
  # --- A. Network Graph Section ---
803
  with st.expander("🔗 Entity Co-occurrence Network Graph", expanded=True):
804
- # This calls the CORRECTED generate_network_graph function
805
  st.plotly_chart(generate_network_graph(df, st.session_state.last_text, entity_color_map), use_container_width=True)
806
 
807
  # --- B. Topic Modeling Section ---
@@ -871,7 +932,7 @@ if st.session_state.show_results:
871
  # 5. White-Label Configuration
872
  st.markdown("---")
873
  st.markdown("### 5. White-Label Report Configuration 🎨")
874
- default_report_title = "Fixed Entity Analysis Report"
875
  custom_report_title = st.text_input(
876
  "Type Your Report Title (for HTML Report), and then press Enter.",
877
  value=default_report_title
@@ -918,4 +979,5 @@ if st.session_state.show_results:
918
  file_name="ner_topic_full_report.html",
919
  mime="text/html",
920
  use_container_width=True
921
- )
 
 
31
 
32
  # --- Fixed Label Definitions and Mappings ---
33
  FIXED_LABELS = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
34
+ DEFAULT_CUSTOM_LABELS = "person, location, organization, product, date, time, event" # <-- REINSTATED
35
  FIXED_ENTITY_COLOR_MAP = {
36
  "person": "#10b981", # Green
37
  "country": "#3b82f6", # Blue
 
52
  "Numbers & Finance": ["money", "cardinal"]}
53
  REVERSE_FIXED_CATEGORY_MAPPING = {label: category for category, label_list in FIXED_CATEGORY_MAPPING.items() for label in label_list}
54
 
 
 
 
55
  # --- Dynamic Color Generator for Custom Labels ---
56
+ COLOR_PALETTE = cycle(px.colors.qualitative.Alphabet + px.colors.qualitative.Bold) # Use a larger palette
57
 
58
  def extract_label(node_name):
59
  """Extracts the label from a node string like 'Text (Label)'."""
 
67
  def get_dynamic_color_map(active_labels, fixed_map):
68
  """Generates a color map, using fixed colors if available, otherwise dynamic colors."""
69
  color_map = {}
70
+
71
+ # If the active labels exactly match the fixed set, use the fixed map
72
+ if set(active_labels) == set(fixed_map.keys()):
73
  return fixed_map
74
 
75
+ # Otherwise, generate a dynamic map, prioritizing fixed colors
76
+ # Ensure the color palette resets for consistency across sessions
77
+ global COLOR_PALETTE
78
+ COLOR_PALETTE = cycle(px.colors.qualitative.Alphabet + px.colors.qualitative.Bold)
79
+
80
  for label in active_labels:
81
  if label in fixed_map:
82
  color_map[label] = fixed_map[label]
 
89
  if df_entities.empty:
90
  return text
91
 
92
+ # Ensure the DataFrame has a unique index before sorting/converting
 
93
  df_entities = df_entities.copy().reset_index(drop=True)
94
 
95
  entities = df_entities.sort_values(by='start', ascending=False).to_dict('records')
 
541
  st.markdown("For any errors or inquiries, please contact us at [info@your-company.com](mailto:info@your-company.com)")
542
 
543
  # --- Model Loading ---
544
+ @st.cache_resource
545
  def load_ner_model(labels):
546
  """Loads the GLiNER model and caches it."""
547
  try:
548
+ # GLiNER model is loaded with constraints based on the active labels list
549
  return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints=labels)
550
  except Exception as e:
551
  # print(f"FATAL ERROR: Failed to load NER model: {e}")
 
570
  "general public by October 1st. The goal is to deploy the **Astra** v2 platform before the next solar eclipse event in 2026.")
571
 
572
  # -----------------------------------
573
+ # --- Session State Initialization (Custom Label Reinstatement) ---
574
  if 'show_results' not in st.session_state: st.session_state.show_results = False
575
  if 'my_text_area' not in st.session_state: st.session_state.my_text_area = DEFAULT_TEXT
576
  if 'last_text' not in st.session_state: st.session_state.last_text = ""
 
578
  if 'elapsed_time' not in st.session_state: st.session_state.elapsed_time = 0.0
579
  if 'topic_results' not in st.session_state: st.session_state.topic_results = None
580
  if 'active_labels_list' not in st.session_state: st.session_state.active_labels_list = FIXED_LABELS
581
+ if 'is_custom_mode' not in st.session_state: st.session_state.is_custom_mode = "Fixed Labels" # Re-use for radio
582
+ if 'custom_labels_input' not in st.session_state: st.session_state.custom_labels_input = DEFAULT_CUSTOM_LABELS
583
  if 'num_topics_slider' not in st.session_state: st.session_state.num_topics_slider = 5
584
  if 'num_top_words_slider' not in st.session_state: st.session_state.num_top_words_slider = 10
585
  if 'last_num_topics' not in st.session_state: st.session_state.last_num_topics = None
 
598
  # --- Revised Text Area Input ---
599
  st.markdown("## ✍️ Text Input for Analysis")
600
 
601
+ word_limit = 2000
602
  text = st.text_area(
603
  f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter",
604
  height=250,
 
608
  word_count = len(text.split())
609
  st.markdown(f"**Word count:** {word_count}/{word_limit}")
610
 
611
+ # --- Custom/Fixed Label Selector ---
612
  st.markdown("---")
613
+ st.markdown("### 🏷️ Entity Label Mode Selection")
614
+ mode = st.radio(
615
+ "Select Entity Recognition Mode:",
616
+ ["Fixed Labels", "Custom Labels"],
617
+ key='is_custom_mode',
618
+ horizontal=True,
619
+ help="Fixed Labels use a predefined set. Custom Labels let you define your own."
620
+ )
621
+
622
+ active_labels = []
623
+ if mode == "Fixed Labels":
624
+ active_labels = FIXED_LABELS
625
+ st.info(f"Fixed Labels active: **{', '.join(active_labels)}**")
626
+ else:
627
+ custom_labels_input = st.text_input(
628
+ "Enter your custom labels, separated by commas (e.g., product, feature, ticket_id):",
629
+ value=st.session_state.custom_labels_input,
630
+ key='custom_labels_input',
631
+ help="The labels must be non-empty and comma-separated."
632
+ )
633
+ # Clean and set active labels from user input
634
+ active_labels = [label.strip().lower() for label in custom_labels_input.split(',') if label.strip()]
635
+ if not active_labels:
636
+ st.error("Please enter at least one custom label.")
637
+ active_labels = [] # Prevents model run if empty
638
+ else:
639
+ st.info(f"Custom Labels active: **{', '.join(active_labels)}**")
640
+
641
+ st.session_state.active_labels_list = active_labels
642
+ current_num_topics = st.session_state.num_topics_slider
643
+ current_num_top_words = st.session_state.num_top_words_slider
644
 
645
+ # --- Buttons ---
646
  col_results, col_clear = st.columns([1, 1])
647
 
648
  with col_results:
649
+ run_button = st.button("Analyze Text", key='run_results', use_container_width=True, type="primary", disabled=not active_labels)
650
 
651
  with col_clear:
652
  st.button("Clear text", on_click=clear_text, use_container_width=True)
653
 
 
 
 
 
 
 
654
  # --- Results Trigger and Processing (Fixed for index error) ---
655
  if run_button:
656
  if text.strip() and word_count <= word_limit:
657
 
658
+ # 1. Determine Active Labels (Already done above, just referencing)
659
  active_labels = st.session_state.active_labels_list
660
 
661
  # Caching Logic: Check if we need to re-run the full process
662
  should_rerun_full_analysis = (
663
  text.strip() != st.session_state.last_text.strip() or
664
+ set(active_labels) != set(st.session_state.last_active_labels if st.session_state.last_active_labels else [])
665
  )
666
 
667
  if should_rerun_full_analysis:
668
  # 2. Rerunning Full Analysis
669
  CHUNKING_THRESHOLD = 500
670
  should_chunk = word_count > CHUNKING_THRESHOLD
671
+ mode_msg = "custom labels" if mode == "Custom Labels" else "fixed labels"
672
  if should_chunk:
673
  mode_msg += " with **chunking** for large text"
674
 
 
691
  df = pd.DataFrame(all_entities)
692
 
693
  if not df.empty:
 
694
  df = df.reset_index(drop=True)
695
 
696
+ # --- CATEGORY MAPPING ADJUSTMENT ---
697
+ # Assign fixed labels to their categories, and custom labels to 'User Defined'
698
+ def map_category(label):
699
+ if label in REVERSE_FIXED_CATEGORY_MAPPING:
700
+ return REVERSE_FIXED_CATEGORY_MAPPING[label]
701
+ elif label in active_labels and label not in FIXED_LABELS:
702
+ # This handles any truly custom labels entered by the user
703
+ return 'User Defined Entities'
704
+ else:
705
+ return 'Other'
706
+
707
+ df['category'] = df['label'].apply(map_category)
708
 
709
  df['text'] = df['text'].apply(remove_trailing_punctuation)
710
 
 
729
  elif word_count > word_limit:
730
  st.error(f"Text too long! Please limit your input to {word_limit} words.")
731
  st.session_state.show_results = False
732
+ elif not active_labels:
733
+ st.error("Please ensure your custom label input is not empty.")
734
+ st.session_state.show_results = False
735
  else:
736
  st.warning("Please enter some text to analyze.")
737
  st.session_state.show_results = False
 
760
  )
761
 
762
  # 1. Highlighted Text placed inside an Expander
763
+ with st.expander(f"### 1. Analyzed Text with Highlighted Entities ({mode} Mode)", expanded=False):
 
764
  st.markdown(
765
  highlight_entities(st.session_state.last_text, df, entity_color_map),
766
  unsafe_allow_html=True
 
775
  with tab_category_details:
776
  st.markdown("#### Detailed Entities Table (Grouped by Category)")
777
 
778
+ # Get all unique categories present in the data (Fixed + User Defined)
779
+ unique_categories = list(df['category'].unique())
780
+
781
+ # Ensure fixed categories appear first if present, followed by custom/other
782
+ ordered_categories = []
783
+
784
+ # Add fixed categories in defined order
785
+ for fixed_cat in FIXED_CATEGORY_MAPPING.keys():
786
+ if fixed_cat in unique_categories:
787
+ ordered_categories.append(fixed_cat)
788
+ unique_categories.remove(fixed_cat)
789
+
790
+ # Add User Defined and Other at the end
791
+ if 'User Defined Entities' in unique_categories:
792
+ ordered_categories.append('User Defined Entities')
793
+ unique_categories.remove('User Defined Entities')
794
+
795
+ if 'Other' in unique_categories:
796
+ ordered_categories.append('Other')
797
+ unique_categories.remove('Other')
798
+
799
+ # Add any remaining categories (shouldn't happen with map_category, but for safety)
800
+ ordered_categories.extend(unique_categories)
801
+
802
+ tabs_category = st.tabs(ordered_categories)
803
+ for category, tab in zip(ordered_categories, tabs_category):
804
  df_category = df[df['category'] == category][['text', 'label', 'score', 'start', 'end']].sort_values(by='score', ascending=False)
805
  styled_df_category = color_score_gradient(df_category)
806
  with tab:
 
863
 
864
  # --- A. Network Graph Section ---
865
  with st.expander("🔗 Entity Co-occurrence Network Graph", expanded=True):
 
866
  st.plotly_chart(generate_network_graph(df, st.session_state.last_text, entity_color_map), use_container_width=True)
867
 
868
  # --- B. Topic Modeling Section ---
 
932
  # 5. White-Label Configuration
933
  st.markdown("---")
934
  st.markdown("### 5. White-Label Report Configuration 🎨")
935
+ default_report_title = "Fixed Entity Analysis Report" if mode == "Fixed Labels" else "Custom Entity Analysis Report"
936
  custom_report_title = st.text_input(
937
  "Type Your Report Title (for HTML Report), and then press Enter.",
938
  value=default_report_title
 
979
  file_name="ner_topic_full_report.html",
980
  mime="text/html",
981
  use_container_width=True
982
+ )
983
+