import html
import re
from pathlib import Path

import numpy as np
import streamlit as st
import tensorflow as tf

# Remote copy of the trained model, fetched when no local model_5.keras exists.
MODEL_URL = "https://huggingface.co/BILALfym/skimlit-model/resolve/main/model_5.keras"

# Model output index -> section name. This must match the label indexing used
# at training time (alphabetical order, as sklearn's LabelEncoder produces).
LABELS = ['Background', 'Conclusions', 'Methods', 'Objective', 'Results']

# Order in which classified sections are displayed.
SECTION_ORDER = ['Background', 'Objective', 'Methods', 'Results', 'Conclusions']

# Emoji shown next to each section heading.
LABEL_EMOJIS = {
    'Background': '📚',
    'Objective': '🎯',
    'Methods': '🔬',
    'Results': '📊',
    'Conclusions': '✅',
}

# Sentence boundary: whitespace that follows '.', '!' or '?'.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')

# Set page config
st.set_page_config(
    page_title="SkimLit - Abstract Classifier",
    page_icon="📄",
    layout="wide",
)

# Custom CSS
# NOTE(review): the string below is empty, so this call injects no styles;
# fill it in (or drop the call) if custom section styling is intended.
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def load_model_and_encoder():
    """Load the trained classifier and the sentence encoder.

    The Keras model is read from ``model_5.keras`` next to this script; if the
    file is missing it is first downloaded from ``MODEL_URL``. The result is
    cached with ``st.cache_resource`` so reruns reuse the loaded objects.

    Returns:
        ``(model, encoder)`` on success, or ``(None, None)`` after showing
        the error in the UI.
    """
    try:
        from sentence_transformers import SentenceTransformer
        import urllib.request

        script_dir = Path(__file__).parent
        model_path = script_dir / 'model_5.keras'

        # Load sentence encoder
        encoder = SentenceTransformer("all-MiniLM-L6-v2")

        # Load the model - use the local file if present, otherwise download it
        if not model_path.exists():
            st.info("Downloading model... (first time only)")
            # Download to a side file and move it into place only once the
            # transfer completes, so an interrupted download never leaves a
            # truncated model_5.keras that every later run would fail to load.
            partial_path = model_path.with_name(model_path.name + ".part")
            try:
                urllib.request.urlretrieve(MODEL_URL, str(partial_path))
                partial_path.replace(model_path)
            finally:
                partial_path.unlink(missing_ok=True)

        model = tf.keras.models.load_model(str(model_path))
        return model, encoder
    except Exception as e:
        st.error(f"Error loading: {e}")
        return None, None


def _one_hot(index, depth):
    """Return a length-``depth`` one-hot vector; all zeros if ``index`` is out of range."""
    vec = np.zeros(depth)
    # Bounds check on both sides: a negative index must not wrap to the end.
    if 0 <= index < depth:
        vec[index] = 1
    return vec


def encode_line_number(line_number, max_value=15):
    """Encode a sentence's position as a one-hot vector of length ``max_value``.

    Positions ``>= max_value`` (or negative) encode as all zeros.
    """
    return _one_hot(line_number, max_value)


def encode_total_lines(total_lines, max_value=20):
    """Encode the abstract's line count as a one-hot vector of length ``max_value``.

    Counts ``>= max_value`` (or negative) encode as all zeros.
    """
    return _one_hot(total_lines, max_value)


def predict_labels(sentences, model, encoder):
    """Predict a section label for every sentence of an abstract.

    All sentences are sent to the model in one batched ``model.predict`` call
    instead of one call per sentence.

    Args:
        sentences: Sentence strings in the order they appear in the abstract.
        model: Keras model whose inputs are named ``token_inputs``,
            ``char_inputs``, ``line_number_inputs`` and ``total_lines_inputs``.
        encoder: SentenceTransformer used to embed the sentences.

    Returns:
        One dict per sentence, in input order, with keys ``sentence``,
        ``label_id``, ``confidence`` (the maximum probability) and
        ``probabilities``. Returns an empty list if the model or encoder is
        missing, ``sentences`` is empty, or encoding/prediction fails (the
        error is reported in the UI).
    """
    if model is None or encoder is None or not sentences:
        return []

    total_sentences = len(sentences)

    # Encode all sentences at once
    try:
        embeddings = encoder.encode(sentences, batch_size=32, show_progress_bar=False)
    except Exception as e:
        st.error(f"Error encoding sentences: {e}")
        return []

    token_inputs = np.asarray(embeddings, dtype=np.float32)
    # Character-level input: each sentence as space-separated characters.
    char_inputs = [" ".join(sentence) for sentence in sentences]
    # Positional inputs: one one-hot row per sentence.
    line_inputs = np.stack(
        [encode_line_number(idx, max_value=15) for idx in range(total_sentences)]
    ).astype(np.float32)
    # NOTE(review): every row encodes len(sentences). If the training data
    # stored total_lines as len(abstract) - 1, this is off by one -- confirm
    # against the training preprocessing.
    total_inputs = np.tile(
        encode_total_lines(total_sentences, max_value=20).astype(np.float32),
        (total_sentences, 1),
    )

    try:
        probs = model.predict(
            {
                'token_inputs': tf.constant(token_inputs, dtype=tf.float32),
                'char_inputs': tf.constant(char_inputs, dtype=tf.string),
                'line_number_inputs': tf.constant(line_inputs, dtype=tf.float32),
                'total_lines_inputs': tf.constant(total_inputs, dtype=tf.float32),
            },
            verbose=0,
        )
    except Exception as e:
        st.warning(f"Error predicting: {str(e)[:80]}")
        return []

    return [
        {
            'sentence': sentence,
            'label_id': int(np.argmax(row)),
            'confidence': float(np.max(row)),
            'probabilities': [float(p) for p in row],
        }
        for sentence, row in zip(sentences, probs)
    ]


def get_label_name(label_id):
    """Map a model output index to its section name ('Unknown' if out of range).

    Assumes training labels were indexed in alphabetical order, as sklearn's
    LabelEncoder does.
    """
    return LABELS[label_id] if 0 <= label_id < len(LABELS) else 'Unknown'


def get_emoji(label_name):
    """Return the emoji for a section heading (a generic page emoji for unknown labels)."""
    return LABEL_EMOJIS.get(label_name, '📄')


def split_sentences(text):
    """Split text into non-empty, stripped sentences at whitespace after '.', '!' or '?'."""
    return [s.strip() for s in _SENTENCE_BOUNDARY.split(text.strip()) if s.strip()]


def render_sections(predictions):
    """Group predicted sentences by section and render each non-empty section."""
    sections = {name: [] for name in SECTION_ORDER}
    for pred in predictions:
        label = get_label_name(pred['label_id'])
        # setdefault keeps an out-of-range label ('Unknown') from raising KeyError;
        # such sentences are shown after the standard sections.
        sections.setdefault(label, []).append(pred['sentence'])

    for section_name, sentences_in_section in sections.items():
        if not sentences_in_section:
            continue
        st.markdown(f"### {get_emoji(section_name)} {section_name}")
        # The text comes from the user, so escape it before rendering it as HTML.
        section_text = html.escape(" ".join(sentences_in_section))
        st.markdown(f"<div>{section_text}</div>", unsafe_allow_html=True)


# Main app
st.title("📄 SkimLit - Abstract Section Classifier")
st.write("Organize your scientific abstract into structured sections")

# Load model; stop rendering the rest of the page if it is unavailable.
model, encoder = load_model_and_encoder()
if model is None or encoder is None:
    st.stop()

# Input section
st.markdown("---")
input_method = st.radio(
    "Choose input:",
    ["Sample abstract", "Enter your text"]
)

if input_method == "Sample abstract":
    sample = """Background: Cardiovascular disease remains a leading cause of mortality globally. Early detection through biomarkers can improve patient outcomes.
Objective: This study aims to identify novel cardiovascular biomarkers.
Methods: We conducted a prospective cohort study of 500 participants over 5 years, collecting blood samples for mass spectrometry analysis.
Results: We identified three novel biomarkers with 85% sensitivity and 90% specificity for early cardiovascular disease detection.
Conclusions: These biomarkers show significant promise and warrant further validation in independent cohorts."""
    text = st.text_area("Abstract:", value=sample, height=200)
else:
    text = st.text_area(
        "Paste your abstract:",
        height=200,
        placeholder="Enter scientific abstract..."
    )

# Classify button
if st.button("🚀 Classify", use_container_width=True):
    if not text.strip():
        st.warning("Please enter some text.")
    else:
        sentences = split_sentences(text)
        if not sentences:
            st.warning("No sentences found.")
        else:
            with st.spinner("Classifying..."):
                predictions = predict_labels(sentences, model, encoder)

            if predictions:
                st.markdown("---")
                st.subheader("📋 Classified Abstract")
                render_sections(predictions)
            else:
                st.error("Could not generate predictions.")

st.markdown("---")
st.caption("🔬 SkimLit | Scientific Abstract Classifier")