Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import trafilatura | |
| import numpy as np | |
| import pandas as pd | |
| from tensorflow.lite.python.interpreter import Interpreter | |
| import requests | |
# File paths
MODEL_PATH = "./model.tflite"      # TFLite classifier weights
VOCAB_PATH = "./vocab.txt"         # wordpiece vocabulary, one token per line (line index == token id)
LABELS_PATH = "./taxonomy_v2.csv"  # CSV mapping topic IDs to names ("ID" and "Topic" columns)
def load_vocab(path=None):
    """Load the wordpiece vocabulary as a list of tokens.

    Args:
        path: Optional vocabulary file path; defaults to VOCAB_PATH.

    Returns:
        list[str]: one token per line, in file order, so a token's
        list index is its id.
    """
    # Explicit encoding: BERT-style vocab files ship as UTF-8, and the
    # platform default encoding may differ (e.g. cp1252 on Windows).
    with open(path if path is not None else VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = [line.strip() for line in f]
    return vocab
def load_labels(path=None):
    """Load the topic taxonomy as an id -> topic-name mapping.

    Args:
        path: Optional CSV path; defaults to LABELS_PATH. The file must
            contain "ID" and "Topic" columns.

    Returns:
        dict[int, str]: topic id mapped to its human-readable name.
    """
    taxonomy = pd.read_csv(path if path is not None else LABELS_PATH)
    # Coerce IDs to int so lookups with integer prediction indices succeed.
    taxonomy["ID"] = taxonomy["ID"].astype(int)
    labels_dict = taxonomy.set_index("ID")["Topic"].to_dict()
    return labels_dict
def load_model():
    """Create the TFLite interpreter for the classifier and return its handles.

    Returns:
        tuple: (interpreter, input_details, output_details) ready for inference.

    Raises:
        Re-raises any exception from model loading after surfacing it in the
        Streamlit UI via st.error.
    """
    try:
        # Use TensorFlow Lite Interpreter
        interp = Interpreter(model_path=MODEL_PATH)
        interp.allocate_tensors()
        in_details = interp.get_input_details()
        out_details = interp.get_output_details()
        return interp, in_details, out_details
    except Exception as exc:
        st.error(f"Failed to load the model: {exc}")
        raise
def preprocess_text(text, vocab, max_length=128):
    """Convert raw text into fixed-length BERT-style model inputs.

    Args:
        text: Input string; tokenized by whitespace splitting.
        vocab: Token list where a token's list index is its id. Must contain
            "[UNK]", used for out-of-vocabulary words.
        max_length: Fixed sequence length; input is truncated and zero-padded.

    Returns:
        tuple of np.ndarray, each int32 with shape (1, max_length):
            (token ids, attention mask, token type ids).
    """
    # Build an O(1) lookup table once. The previous list-based code called
    # `word in vocab` plus `vocab.index(...)` per word — O(len(vocab)) each.
    token_to_id = {token: idx for idx, token in enumerate(vocab)}

    words = text.split()[:max_length]  # whitespace tokenize, then truncate
    ids = [token_to_id[w] if w in token_to_id else token_to_id["[UNK]"] for w in words]
    ids += [0] * (max_length - len(ids))  # right-pad with id 0
    token_ids = np.array(ids, dtype=np.int32)

    # Mask is 1 over real tokens, 0 over padding.
    attention_mask = np.zeros(max_length, dtype=np.int32)
    attention_mask[: len(words)] = 1
    # Single-segment input: every token type id is 0.
    token_type_ids = np.zeros(max_length, dtype=np.int32)

    # Add the leading batch dimension expected by the model.
    return token_ids[np.newaxis, :], attention_mask[np.newaxis, :], token_type_ids[np.newaxis, :]
def classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids):
    """Run one TFLite inference pass and return the raw score vector.

    NOTE(review): inputs are bound positionally to input_details[0..2]; this
    assumes the converted model's input order is (word ids, mask, type ids) —
    confirm against the model's signature.

    Returns:
        np.ndarray: the first (and only) row of the model's output tensor.
    """
    feeds = (input_word_ids, input_mask, input_type_ids)
    for slot, tensor in enumerate(feeds):
        interpreter.set_tensor(input_details[slot]["index"], tensor)
    interpreter.invoke()
    scores = interpreter.get_tensor(output_details[0]["index"])
    return scores[0]
def fetch_url_content(url):
    """Download the raw HTML for *url*, reporting failures in the Streamlit UI.

    Returns:
        str | None: the response body on HTTP 200, otherwise None.
    """
    # Browser-like headers: some sites reject requests' default User-Agent.
    browser_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
        "Accept-Language": "en-US,en;q=0.9",
        "Accept-Encoding": "gzip, deflate, br",
    }
    try:
        resp = requests.get(url, headers=browser_headers, cookies={}, timeout=10)
        if resp.status_code == 200:
            return resp.text
        st.error(f"Failed to fetch content. Status code: {resp.status_code}")
        return None
    except Exception as e:
        st.error(f"Error fetching content: {e}")
        return None
# Streamlit app: fetch a URL, extract its main text, and classify its topic.
st.title("Topic Classification from URL")
url = st.text_input("Enter a URL:", "")
if url:
    st.write("Extracting content from the URL...")
    raw_content = fetch_url_content(url)
    if raw_content:
        # Strip page boilerplate and keep only the main article text.
        content = trafilatura.extract(raw_content)
        if content:
            st.write("Content extracted successfully!")
            st.write(content[:500])  # Display a snippet of the content
            # Load resources
            vocab = load_vocab()
            labels_dict = load_labels()
            interpreter, input_details, output_details = load_model()
            # Preprocess content and classify
            input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
            predictions = classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids)
            # Display classification
            st.write("Topic Classification:")
            # Sort scores descending and keep the five highest-scoring ids.
            sorted_indices = np.argsort(predictions)[::-1][:5]  # Top 5 topics
            for idx in sorted_indices:
                # NOTE(review): idx is a numpy integer while labels_dict keys are
                # plain ints; they hash equal so .get works — confirm intended.
                topic = labels_dict.get(idx, "Unknown Topic")
                st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")
        else:
            st.error("Unable to extract content from the fetched HTML.")
    else:
        st.error("Failed to fetch the URL.")