Sage-Text-1B / app.py
itriedcoding's picture
Update app.py for better compatibility
8a43bf1 verified
Raw
History Blame Contribute Delete
4.39 kB
"""
PlainEnglish-1B Gradio Demo App
Interactive text generation interface for HuggingFace Spaces and ModelScope Studio.
"""
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model repository IDs on the two hosting platforms this demo supports.
HF_MODEL_ID = "itriedcoding/PlainEnglish-1B"
MS_MODEL_ID = "NeuraAI/PlainEnglish-1B"

# Use the ModelScope repo when a (non-empty) ModelScope token is present in
# the environment; otherwise fall back to the Hugging Face Hub repo.
if os.environ.get("MODELSCOPE_API_TOKEN"):
    MODEL_ID = MS_MODEL_ID
else:
    MODEL_ID = HF_MODEL_ID

print(f"Loading model from: {MODEL_ID}")

# NOTE: trust_remote_code=True lets transformers execute Python code shipped
# in the model repository; only point MODEL_ID at repos you trust.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float32,
)
# Inference only: switch off training-mode layers such as dropout.
model.eval()
print("Model loaded successfully!")
def generate_text(
    prompt,
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
):
    """Generate a sampled continuation of *prompt* with the loaded model.

    Args:
        prompt: Text to continue. ``None``, empty, or whitespace-only input
            short-circuits with a hint message instead of calling the model.
        max_new_tokens: Maximum number of tokens to generate (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling probability mass (coerced to float).
        top_k: Number of highest-probability tokens kept when sampling
            (coerced to int).
        repetition_penalty: Repetition penalty passed to ``generate``
            (coerced to float).

    Returns:
        The first generated sequence decoded with special tokens removed.
        For a causal LM this is the prompt followed by the continuation.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    inputs = tokenizer(prompt, return_tensors="pt")
    # Always place the inputs on the model's device. With device_map="auto"
    # the model may live on a GPU, the CPU, or another accelerator (e.g. MPS);
    # moving only when CUDA is available can leave inputs on the wrong device.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            do_sample=True,
            # Reuse EOS as the padding id so generate() does not warn when the
            # tokenizer defines no dedicated pad token.
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def count_parameters():
    """Describe the size of the module-level ``model`` as a display string.

    Returns:
        The exact parameter count with thousands separators, followed by the
        same count in millions to one decimal place, e.g. "1,234,567 (1.2M)".
    """
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    in_millions = n_params / 1e6
    return f"{n_params:,} ({in_millions:.1f}M)"
# Page styling: hide the default Gradio footer and center a narrow layout.
css = """
footer { display: none !important; }
.gradio-container { max-width: 800px; margin: auto; }
"""

with gr.Blocks(css=css, title="PlainEnglish-1B") as demo:
    # Markdown body kept flush-left: indented lines would render as code.
    gr.Markdown(
        """
# PlainEnglish-1B
A 1B parameter text generation model fine-tuned for clear, plain English.
Enter a prompt and adjust parameters to generate text.
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=4,
            )
        with gr.Column(scale=1):
            max_tokens = gr.Slider(
                minimum=10,
                maximum=500,
                value=200,
                step=10,
                label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
            )
            # step=1: with minimum=1 and step=5 the slider grid was
            # 1, 6, ..., 96, so neither the default (50) nor the maximum
            # (100) could actually be selected.
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=50,
                step=1,
                label="Top-k",
            )
            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition Penalty",
            )
    generate_btn = gr.Button("Generate Text", variant="primary")
    output_text = gr.Textbox(
        label="Generated Text",
        lines=8,
    )
    gr.Markdown(f"**Model Parameters**: {count_parameters()}")

    # Components passed positionally to generate_text; the order must match
    # its parameters. Shared by both triggers so they cannot drift apart.
    generation_inputs = [
        prompt_input,
        max_tokens,
        temperature,
        top_p,
        top_k,
        rep_penalty,
    ]
    # Generate either from the button or by pressing Enter in the prompt box.
    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )
    prompt_input.submit(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )

    # Clicking an example only fills in the prompt box.
    gr.Examples(
        examples=[
            ["The meaning of life is"],
            ["In the year 2025, artificial intelligence"],
            ["The best way to learn programming"],
            ["Scientists recently discovered that"],
            ["Once upon a time, in a small village"],
        ],
        inputs=prompt_input,
    )

if __name__ == "__main__":
    demo.launch()