File size: 4,392 Bytes
e3d95ba
 
 
 
 
8a43bf1
e3d95ba
 
 
 
8a43bf1
 
e3d95ba
8a43bf1
 
 
e3d95ba
 
 
 
8a43bf1
e3d95ba
 
 
8a43bf1
e3d95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
8a43bf1
 
e3d95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
PlainEnglish-1B Gradio Demo App
Interactive text generation interface for HuggingFace Spaces and ModelScope Studio.
"""

import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

HF_MODEL_ID = "itriedcoding/PlainEnglish-1B"
MS_MODEL_ID = "NeuraAI/PlainEnglish-1B"

# When a ModelScope token is present in the environment (presumably because
# the app is running on ModelScope Studio) use the ModelScope repository id;
# otherwise fall back to the HuggingFace Hub id.
# NOTE(review): both ids are resolved through transformers' from_pretrained;
# confirm the ModelScope id is actually reachable that way.
if os.environ.get("MODELSCOPE_API_TOKEN"):
    MODEL_ID = MS_MODEL_ID
else:
    MODEL_ID = HF_MODEL_ID

print(f"Loading model from: {MODEL_ID}")

# The repository may ship custom tokenizer/model code, hence trust_remote_code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

_model_load_kwargs = dict(
    torch_dtype=torch.float32,  # load weights in full precision
    device_map="auto",          # let transformers decide device placement
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_load_kwargs)
model.eval()  # evaluation mode: the app only runs inference

print("Model loaded successfully!")


def generate_text(
    prompt,
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
):
    """Sample a continuation of *prompt* from the loaded model.

    Args:
        prompt: Text to continue. ``None``, empty or whitespace-only input is
            rejected with a short message instead of being sent to the model.
        max_new_tokens: Upper bound on the number of newly generated tokens
            (cast to ``int``).
        temperature: Sampling temperature (cast to ``float``).
        top_p: Nucleus-sampling probability mass (cast to ``float``).
        top_k: Number of highest-probability tokens kept at each step
            (cast to ``int``).
        repetition_penalty: Penalty for repeating tokens (cast to ``float``).

    Returns:
        The decoded first output sequence with special tokens removed, or
        ``"Please enter a prompt."`` when the prompt is blank.
    """
    # A cleared Gradio textbox may arrive as None; treat it like "".
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    inputs = tokenizer(prompt, return_tensors="pt")
    # Always move the inputs to the model's device. Gating this on CUDA
    # availability fails if device placement picks a non-CUDA accelerator,
    # and .to() is a no-op when the devices already match.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model.generate(
            **inputs,
            # UI slider values may arrive as floats; cast to the types
            # generate() expects.
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            do_sample=True,
            # Pad with the EOS id; presumably the tokenizer defines no
            # dedicated pad token.
            pad_token_id=tokenizer.eos_token_id,
        )

    # outputs[0] is the single generated sequence (prompt + continuation).
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def count_parameters():
    """Describe the size of the module-level ``model`` as a display string.

    The result is the exact parameter count with thousands separators,
    followed by the count in millions to one decimal place, for example
    ``"1,234,567 (1.2M)"``.
    """
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    in_millions = n_params / 1e6
    return f"{n_params:,} ({in_millions:.1f}M)"


# Page-level CSS: hide Gradio's footer and constrain the app to a
# horizontally centred column at most 800px wide.
css = """
footer { display: none !important; }
.gradio-container { max-width: 800px; margin: auto; }
"""

with gr.Blocks(css=css, title="PlainEnglish-1B") as demo:
    gr.Markdown(
        """
        # PlainEnglish-1B
        A 1B parameter text generation model fine-tuned for clear, plain English.
        Enter a prompt and adjust parameters to generate text.
        """
    )

    with gr.Row():
        # Wide column (scale 3) for the prompt, narrow column (scale 1) for
        # the sampling controls that are forwarded to generate_text.
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=4,
            )
        with gr.Column(scale=1):
            max_tokens = gr.Slider(
                minimum=10,
                maximum=500,
                value=200,
                step=10,
                label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=50,
                # step=1: with minimum=1 a step of 5 only yields 1, 6, ..., 96,
                # so neither the default (50) nor the maximum (100) was on
                # the slider's step grid.
                step=1,
                label="Top-k",
            )
            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition Penalty",
            )

    generate_btn = gr.Button("Generate Text", variant="primary")

    output_text = gr.Textbox(
        label="Generated Text",
        lines=8,
    )

    # Computed once when the UI is built.
    gr.Markdown(f"**Model Parameters**: {count_parameters()}")

    # The button and pressing Enter in the prompt box trigger the same
    # generation, so the input list (in generate_text's positional order)
    # is defined once and shared by both events.
    generation_inputs = [
        prompt_input,
        max_tokens,
        temperature,
        top_p,
        top_k,
        rep_penalty,
    ]

    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )

    prompt_input.submit(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )

    # Clickable sample prompts that fill the prompt box.
    gr.Examples(
        examples=[
            ["The meaning of life is"],
            ["In the year 2025, artificial intelligence"],
            ["The best way to learn programming"],
            ["Scientists recently discovered that"],
            ["Once upon a time, in a small village"],
        ],
        inputs=prompt_input,
    )

if __name__ == "__main__":
    demo.launch()