Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,8 +2,9 @@ import streamlit as st
|
|
| 2 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 3 |
|
| 4 |
model_names = {
|
|
|
|
| 5 |
"eluether1.3b":"EleutherAI/gpt-neo-1.3B",
|
| 6 |
-
|
| 7 |
|
| 8 |
|
| 9 |
def generate_texts(pipeline, input_text, **generator_args):
|
|
@@ -42,6 +43,7 @@ num_return_sequences = st.sidebar.slider("Num Return Sequences", min_value = 1,
|
|
| 42 |
num_beams = st.sidebar.slider("Num Beams", min_value = 2, max_value=6, value = 4)
|
| 43 |
top_k = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 90)
|
| 44 |
top_p = st.sidebar.slider("Top-p", min_value = 0.4, max_value=1.0, step = 0.05, value = 0.9)
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
if len(sent)<10:
|
|
@@ -56,6 +58,8 @@ output_sequences = generate_texts(pipelines[model_index],
|
|
| 56 |
num_beams=num_beams,
|
| 57 |
temperature=temperature,
|
| 58 |
top_k=top_k,
|
|
|
|
|
|
|
| 59 |
early_stopping=False,
|
| 60 |
top_p=top_p)
|
| 61 |
|
|
|
|
| 2 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 3 |
|
| 4 |
model_names = {
|
| 5 |
+
"gpt2-medium":"gpt2-medium",
|
| 6 |
"eluether1.3b":"EleutherAI/gpt-neo-1.3B",
|
| 7 |
+
}
|
| 8 |
|
| 9 |
|
| 10 |
def generate_texts(pipeline, input_text, **generator_args):
|
|
|
|
| 43 |
num_beams = st.sidebar.slider("Num Beams", min_value = 2, max_value=6, value = 4)
|
| 44 |
top_k = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 90)
|
| 45 |
top_p = st.sidebar.slider("Top-p", min_value = 0.4, max_value=1.0, step = 0.05, value = 0.9)
|
| 46 |
+
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value = 0.45 max_value=2.0, step = 0.1, value = 1.2)
|
| 47 |
|
| 48 |
|
| 49 |
if len(sent)<10:
|
|
|
|
| 58 |
num_beams=num_beams,
|
| 59 |
temperature=temperature,
|
| 60 |
top_k=top_k,
|
| 61 |
+
no_repeat_ngram_size=2,
|
| 62 |
+
repetition_penalty = repetition_penalty,
|
| 63 |
early_stopping=False,
|
| 64 |
top_p=top_p)
|
| 65 |
|