Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,23 +22,23 @@ os.system('python -m spacy download en_core_web_sm')
|
|
| 22 |
nlp = spacy.load("en_core_web_sm")
|
| 23 |
|
| 24 |
# Function for generating text and tokenizing
|
| 25 |
-
def historical_generation(prompt, max_new_tokens=600):
|
| 26 |
prompt = f"### Text ###\n{prompt}"
|
| 27 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
| 28 |
input_ids = inputs["input_ids"].to(device)
|
| 29 |
attention_mask = inputs["attention_mask"].to(device)
|
| 30 |
|
| 31 |
-
# Generate text
|
| 32 |
output = model.generate(
|
| 33 |
input_ids,
|
| 34 |
attention_mask=attention_mask,
|
| 35 |
max_new_tokens=max_new_tokens,
|
| 36 |
pad_token_id=tokenizer.eos_token_id,
|
| 37 |
-
top_k=
|
| 38 |
-
temperature=
|
| 39 |
-
top_p=
|
| 40 |
do_sample=True,
|
| 41 |
-
repetition_penalty=
|
| 42 |
bos_token_id=tokenizer.bos_token_id,
|
| 43 |
eos_token_id=tokenizer.eos_token_id
|
| 44 |
)
|
|
@@ -53,11 +53,11 @@ def historical_generation(prompt, max_new_tokens=600):
|
|
| 53 |
# Tokenize the generated text
|
| 54 |
tokens = tokenizer.tokenize(generated_text)
|
| 55 |
|
| 56 |
-
# Create highlighted text output
|
| 57 |
highlighted_text = []
|
| 58 |
for token in tokens:
|
| 59 |
clean_token = token.replace("Ġ", "") # Remove "Ġ"
|
| 60 |
-
token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
|
| 61 |
highlighted_text.append((clean_token, token_type))
|
| 62 |
|
| 63 |
return highlighted_text, generated_text # Return both tokenized and raw generated text
|
|
@@ -85,8 +85,10 @@ def generate_dependency_parse(generated_text):
|
|
| 85 |
return html_generated
|
| 86 |
|
| 87 |
# Full interface combining text generation and analysis, split across steps
|
| 88 |
-
def full_interface(prompt, max_new_tokens):
|
| 89 |
-
generated_highlight, generated_text = historical_generation(
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Dependency parse of input text
|
| 92 |
tokens_input, pos_count_input, html_input = text_analysis(prompt)
|
|
@@ -101,7 +103,13 @@ def reset_interface():
|
|
| 101 |
# Gradio interface components
|
| 102 |
with gr.Blocks() as iface:
|
| 103 |
prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
|
|
|
|
|
|
|
| 104 |
max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Output components
|
| 107 |
highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
|
|
@@ -126,7 +134,7 @@ with gr.Blocks() as iface:
|
|
| 126 |
generate_button = gr.Button(value="Generate Text and Initial Outputs")
|
| 127 |
generate_button.click(
|
| 128 |
full_interface,
|
| 129 |
-
inputs=[prompt, max_new_tokens],
|
| 130 |
outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
|
| 131 |
)
|
| 132 |
|
|
|
|
| 22 |
nlp = spacy.load("en_core_web_sm")
|
| 23 |
|
| 24 |
# Function for generating text and tokenizing
|
| 25 |
+
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
|
| 26 |
prompt = f"### Text ###\n{prompt}"
|
| 27 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
| 28 |
input_ids = inputs["input_ids"].to(device)
|
| 29 |
attention_mask = inputs["attention_mask"].to(device)
|
| 30 |
|
| 31 |
+
# Generate text with customizable parameters
|
| 32 |
output = model.generate(
|
| 33 |
input_ids,
|
| 34 |
attention_mask=attention_mask,
|
| 35 |
max_new_tokens=max_new_tokens,
|
| 36 |
pad_token_id=tokenizer.eos_token_id,
|
| 37 |
+
top_k=top_k,
|
| 38 |
+
temperature=temperature,
|
| 39 |
+
top_p=top_p,
|
| 40 |
do_sample=True,
|
| 41 |
+
repetition_penalty=repetition_penalty,
|
| 42 |
bos_token_id=tokenizer.bos_token_id,
|
| 43 |
eos_token_id=tokenizer.eos_token_id
|
| 44 |
)
|
|
|
|
| 53 |
# Tokenize the generated text
|
| 54 |
tokens = tokenizer.tokenize(generated_text)
|
| 55 |
|
| 56 |
+
# Create the highlighted text output, removing "Ġ" from both the token and its token_type
|
| 57 |
highlighted_text = []
|
| 58 |
for token in tokens:
|
| 59 |
clean_token = token.replace("Ġ", "") # Remove "Ġ"
|
| 60 |
+
token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0].replace("Ġ", "")
|
| 61 |
highlighted_text.append((clean_token, token_type))
|
| 62 |
|
| 63 |
return highlighted_text, generated_text # Return both tokenized and raw generated text
|
|
|
|
| 85 |
return html_generated
|
| 86 |
|
| 87 |
# Full interface combining text generation and analysis, split across steps
|
| 88 |
+
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
|
| 89 |
+
generated_highlight, generated_text = historical_generation(
|
| 90 |
+
prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
|
| 91 |
+
)
|
| 92 |
|
| 93 |
# Dependency parse of input text
|
| 94 |
tokens_input, pos_count_input, html_input = text_analysis(prompt)
|
|
|
|
| 103 |
# Gradio interface components
|
| 104 |
with gr.Blocks() as iface:
|
| 105 |
prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
|
| 106 |
+
|
| 107 |
+
# Sliders for the text-generation (sampling) parameters
|
| 108 |
max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
|
| 109 |
+
top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
|
| 110 |
+
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.7)
|
| 111 |
+
top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95)
|
| 112 |
+
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
|
| 113 |
|
| 114 |
# Output components
|
| 115 |
highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
|
|
|
|
| 134 |
generate_button = gr.Button(value="Generate Text and Initial Outputs")
|
| 135 |
generate_button.click(
|
| 136 |
full_interface,
|
| 137 |
+
inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
|
| 138 |
outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
|
| 139 |
)
|
| 140 |
|