File size: 2,482 Bytes
9987eb0
 
 
 
 
a6867c7
4111288
9987eb0
a6867c7
4111288
9987eb0
a6867c7
9987eb0
a6867c7
9987eb0
 
 
 
a6867c7
9987eb0
a6867c7
 
 
9987eb0
a6867c7
9987eb0
4111288
 
9987eb0
a6867c7
 
9987eb0
 
 
 
a6867c7
9987eb0
a6867c7
9987eb0
a6867c7
 
 
 
4111288
 
a6867c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9987eb0
a6867c7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st
import torch
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configure the browser tab (title + icon) for this Streamlit page.
st.set_page_config(
    page_title="Review Keypoint Extractor (DistilBART-CNN-12-6)",
    page_icon="🔑",
)

# Hugging Face Hub checkpoint loaded by load_model_and_tokenizer().
model_name = "sshleifer/distilbart-cnn-12-6"

# Cache the model and tokenizer to avoid reloading
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and seq2seq model for ``model_name``.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise it stays on the CPU. ``st.cache_resource`` makes Streamlit
    reuse the returned objects instead of reloading them on every call.

    Returns:
        A ``(tokenizer, model, device)`` tuple, where ``device`` is the
        ``torch.device`` the model was moved to.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(model_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(target_device)
    return tok, seq2seq, target_device

# Keypoint generation function
def generate_keypoint(review, max_new_tokens=64):
    tokenizer, model, device = load_model_and_tokenizer()
    
    start_time = time.time()
    
    # BART-specific prompt (no additional prompt engineering)
    prompt = review
    
    # Inference
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    keypoint = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    
    # Post-process: Normalize "no key point" outputs
    if keypoint.lower() in ["none", "no keypoint", "no key point", "n/a", "na", "", "nothing"]:
        keypoint = "No key point"
    
    elapsed = time.time() - start_time
    return keypoint, elapsed

# Streamlit UI: header, review input, action button, results and footer.
st.title("🔑 Review Keypoint Extractor (DistilBART-CNN-12-6)")
st.write(
    "Enter a product review below to extract its key points using the "
    "sshleifer/distilbart-cnn-12-6 model."
)

# Free-text review typed by the user.
review = st.text_area(
    "Product Review",
    placeholder=(
        "e.g., The Jackery power station is lightweight and charges quickly, "
        "but the battery life could be longer."
    ),
)

# On click: reject blank / whitespace-only input; otherwise run the model and
# show the timing followed by the review/keypoint pair.
if st.button("Extract Keypoint"):
    if not review.strip():
        st.error("⚠️ Please enter a valid review.")
    else:
        with st.spinner("Generating keypoint..."):
            keypoint, elapsed = generate_keypoint(review)
            st.success(f"✅ Keypoint generated in {elapsed:.2f} seconds!")
            st.subheader("Results")
            st.write(f"**Review:** {review}")
            st.write(f"**Keypoint:** {keypoint}")

# Footer: horizontal rule plus attribution links.
st.markdown("---")
st.markdown(
    "Powered by [Hugging Face Transformers](https://huggingface.co/) and "
    "[Streamlit](https://streamlit.io/)"
)