Maddy90 commited on
Commit
115ab92
·
verified ·
1 Parent(s): df93cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -14
app.py CHANGED
@@ -1,26 +1,59 @@
1
  import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  # Load model and tokenizer
5
- model_name = "gpt2"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
- def generate_blog_post(topic):
10
- prompt = f"Write a detailed blog post about {topic}."
11
- inputs = tokenizer.encode(prompt, return_tensors="pt")
12
- outputs = model.generate(inputs, max_length=512, num_return_sequences=1)
13
- blog_post = tokenizer.decode(outputs[0], skip_special_tokens=True)
14
- return blog_post
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Streamlit interface
17
  st.title("Blog Post Generator")
18
- st.write("Enter a topic to generate a detailed blog post.")
19
 
20
- topic = st.text_input("Topic", "")
21
- if st.button("Generate Blog Post"):
22
- if topic:
 
 
23
  blog_post = generate_blog_post(topic)
 
24
  st.write(blog_post)
25
- else:
26
- st.write("Please enter a topic to generate a blog post.")
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
 
4
  # Load model and tokenizer
5
+ model_name = "gpt2-large" # You can change this to a larger GPT model if needed
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
+ # Set up zero-shot classification pipeline
10
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
11
+
12
+ # Function to generate blog post
13
+ def generate_blog_post(topic, max_length=500):
14
+ prompt = f"Write a blog post about {topic}:\n\n"
15
+ inputs = tokenizer(prompt, return_tensors="pt")
16
+
17
+ outputs = model.generate(
18
+ inputs.input_ids,
19
+ max_length=max_length,
20
+ num_return_sequences=1,
21
+ no_repeat_ngram_size=2,
22
+ top_k=50,
23
+ top_p=0.95,
24
+ temperature=0.7,
25
+ )
26
+
27
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
+ return generated_text.replace(prompt, "")
29
+
30
+ # Function to classify blog post
31
+ def classify_blog_post(text, labels):
32
+ result = classifier(text, labels)
33
+ return result
34
 
35
  # Streamlit interface
36
  st.title("Blog Post Generator")
 
37
 
38
+ topic = st.text_input("Enter a topic for your blog post:")
39
+ generate_button = st.button("Generate Blog Post")
40
+
41
+ if generate_button and topic:
42
+ with st.spinner("Generating blog post..."):
43
  blog_post = generate_blog_post(topic)
44
+ st.subheader("Generated Blog Post")
45
  st.write(blog_post)
46
+
47
+ # Classify the generated blog post
48
+ st.subheader("Blog Post Classification")
49
+ labels = ["Technology", "Travel", "Food", "Health", "Finance"]
50
+ classification = classify_blog_post(blog_post, labels)
51
+
52
+ for label, score in zip(classification['labels'], classification['scores']):
53
+ st.write(f"{label}: {score:.2f}")
54
+
55
+ st.sidebar.title("About")
56
+ st.sidebar.info(
57
+ "This app generates a blog post on a given topic using a large GPT model. "
58
+ "It also classifies the generated post using zero-shot classification."
59
+ )