```python
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gdown
import os

# Set the title of the Streamlit app
st.title("Text Summarization with Fine-Tuned BART")

# Download a single model file from Google Drive
def download_model_from_drive(file_id, dest_path):
    url = f'https://drive.google.com/uc?id={file_id}'
    try:
        gdown.download(url, dest_path, quiet=False)
        st.success(f"Downloaded {dest_path}")
    except Exception as e:
        st.error(f"Error downloading {dest_path}: {e}")

# Ensure the model directory exists
model_dir = 'model'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Google Drive file IDs for the model components
file_ids = {
    'model': '1-V2bEtPR9Y3iBXK9zOR-qM5y9hKiQUnF',
    'config': '1-T2etSP_k_3j5LzunWq8viKGQCQ5RMr_',
    'tokenizer': '1-cRYNPWqlNNGRxeztympRRfVuy3hWuMY',
    'vocab': '1-t9AhomeH7YIIpAqCGTok8wjvl0tml0F',
    'merges': '1-l77_KEdK7GBFjMX_6UXGE-ZTGDraaDm'
}

# Download the model files
with st.spinner("Downloading model..."):
    download_model_from_drive(file_ids['model'], os.path.join(model_dir, 'pytorch_model.bin'))
    download_model_from_drive(file_ids['config'], os.path.join(model_dir, 'config.json'))
    download_model_from_drive(file_ids['tokenizer'], os.path.join(model_dir, 'tokenizer.json'))
    download_model_from_drive(file_ids['vocab'], os.path.join(model_dir, 'vocab.json'))
    download_model_from_drive(file_ids['merges'], os.path.join(model_dir, 'merges.txt'))

# Load the model and tokenizer from the downloaded files
def load_model_and_tokenizer():
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {e}")
        return None, None

tokenizer, model = load_model_and_tokenizer()

# Input text from the user
input_text = st.text_area("Enter the text to summarize:")

if st.button("Summarize"):
    if input_text:
        if tokenizer and model:
            try:
                # Tokenize the input text
                inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
                # Generate the summary with beam search
                with torch.no_grad():
                    summary_ids = model.generate(inputs['input_ids'], max_length=150, num_beams=4, early_stopping=True)
                # Decode the generated token IDs back to text
                summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
                st.write(f"Summary: {summary}")
            except Exception as e:
                st.error(f"Error during summarization: {e}")
        else:
            st.error("Model or tokenizer not loaded.")
    else:
        st.write("Please enter some text to summarize.")
```