Spaces:
Sleeping
Sleeping
File size: 2,482 Bytes
9987eb0 a6867c7 4111288 9987eb0 a6867c7 4111288 9987eb0 a6867c7 9987eb0 a6867c7 9987eb0 a6867c7 9987eb0 a6867c7 9987eb0 a6867c7 9987eb0 4111288 9987eb0 a6867c7 9987eb0 a6867c7 9987eb0 a6867c7 9987eb0 a6867c7 4111288 a6867c7 9987eb0 a6867c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import streamlit as st
import torch
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Page setup. This is kept as the very first Streamlit command in the script
# (older Streamlit releases reject set_page_config after any other st.* call).
st.set_page_config(
    page_icon="🔑",
    page_title="Review Keypoint Extractor (DistilBART-CNN-12-6)",
)

# Hugging Face checkpoint id shared by the tokenizer and the model loader.
model_name = "sshleifer/distilbart-cnn-12-6"
# Cache the model and tokenizer to avoid reloading
@st.cache_resource
def load_model_and_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
return tokenizer, model, device
# Decoded model outputs (lower-cased) that mean "nothing to report"; any of
# these is normalized to the canonical "No key point" string.
_NO_KEYPOINT_OUTPUTS = frozenset(
    {"none", "no keypoint", "no key point", "n/a", "na", "", "nothing"}
)


def generate_keypoint(review, max_new_tokens=64):
    """Summarize *review* into a short keypoint with the cached seq2seq model.

    Args:
        review: Raw review text. It is fed to the model as-is (no prompt
            template) and truncated to the tokenizer's maximum input length.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A ``(keypoint, elapsed)`` tuple: the decoded, whitespace-stripped
        output (collapsed to ``"No key point"`` for empty or "none"-style
        answers) and the inference time in seconds. The timer starts after
        the model/tokenizer lookup, so first-call model loading is excluded.
    """
    tokenizer, model, device = load_model_and_tokenizer()
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted while generation is running.
    start = time.perf_counter()

    # DistilBART-CNN is a plain summarizer, so the review itself is the input.
    # A single sequence needs no padding.
    inputs = tokenizer(review, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    keypoint = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    if keypoint.lower() in _NO_KEYPOINT_OUTPUTS:
        keypoint = "No key point"

    return keypoint, time.perf_counter() - start
# Streamlit UI
st.title("🔑 Review Keypoint Extractor (DistilBART-CNN-12-6)")
# Reuse the model_name constant so the description cannot drift from the
# checkpoint that is actually loaded.
st.write(
    "Enter a product review below to extract its key points using the "
    f"{model_name} model."
)

# Input field for review
review = st.text_area(
    "Product Review",
    placeholder="e.g., The Jackery power station is lightweight and charges quickly, but the battery life could be longer.",
)

# Button to generate keypoint
if st.button("Extract Keypoint"):
    if review.strip():
        with st.spinner("Generating keypoint..."):
            keypoint, elapsed = generate_keypoint(review)
        st.success(f"✅ Keypoint generated in {elapsed:.2f} seconds!")
        st.subheader("Results")
        # User input and model output are shown with st.text, which does no
        # Markdown parsing. Interpolating them into Markdown would turn "$"
        # (common in prices) into LaTeX delimiters and reinterpret "*", "_",
        # "#", etc.
        st.markdown("**Review:**")
        st.text(review)
        st.markdown("**Keypoint:**")
        st.text(keypoint)
    else:
        st.error("⚠️ Please enter a valid review.")

# Footer
st.markdown("---")
st.markdown("Powered by [Hugging Face Transformers](https://huggingface.co/) and [Streamlit](https://streamlit.io/)")