Darshan03 committed on
Commit
fabac23
·
verified ·
1 Parent(s): 55d87c1

create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py — Streamlit front-end for a fine-tuned T5 headline-generation model.
# Dependencies: streamlit (UI), transformers (model/tokenizer), huggingface_hub
# (optional auth), torch (inference backend).
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration
import huggingface_hub
import torch

# Login to Hugging Face Hub — only needed if the model repo is private.
# huggingface_hub.login("")

# Load model and tokenizer from the Hub. Downloaded and cached on first run;
# this executes at import time, so app startup blocks until the weights are local.
repo_id = "Darshan03/t5-model-small" # Replace with your hub username and model name
tokenizer = T5Tokenizer.from_pretrained(repo_id)
model = T5ForConditionalGeneration.from_pretrained(repo_id)
14
+ # Function to generate headline
15
# Function to generate headline
def generate_headline(text, model, tokenizer):
    """Generate a headline for *text* with a T5 seq2seq model.

    Args:
        text: Source passage to turn into a headline.
        model: A loaded ``T5ForConditionalGeneration`` (or compatible seq2seq model).
        tokenizer: The tokenizer matching *model*.

    Returns:
        The generated headline as a plain string (special tokens stripped).
    """
    MAX_TOKEN_LEN = 256
    device = "cuda" if torch.cuda.is_available() else "cpu"  # Check for GPU availability
    model = model.to(device)  # Move model to the device

    # Tokenize with truncation to the model's input budget. The original code
    # forced padding="max_length" but never passed the attention mask to
    # generate(), so the model attended to pad tokens — feeding the mask
    # through fixes that (and padding a single input is unnecessary anyway).
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_TOKEN_LEN
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    with torch.no_grad():  # Inference doesn't need gradients
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )

    headline = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return headline
29
+
30
# Streamlit UI
st.title("T5 Headline Generator")

# Input text box
input_text = st.text_area("Enter your text:")

if st.button("Generate Headline"):
    # Reject empty AND whitespace-only input; the original truthiness check
    # let strings like "   " reach the model.
    if input_text and input_text.strip():
        headline = generate_headline(input_text, model, tokenizer)
        st.subheader("Generated Headline:")
        st.write(headline)
    else:
        st.error("Please enter some text.")

# Refresh button
if st.button("Refresh"):
    # st.experimental_rerun() was deprecated (Streamlit 1.27) and later
    # removed; st.rerun() is the supported replacement.
    st.rerun()