|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
# Shared state for the demo: both stay None until load_model() has run
# successfully; generate_text() checks them before doing any work.
model, tokenizer = None, None
|
|
|
|
|
def load_model():
    """Load the SmallLM tokenizer and model into the module-level globals.

    Both objects are loaded into local variables first and are only published
    to the ``tokenizer`` / ``model`` globals once *both* have loaded. A failure
    part-way through therefore never leaves a new tokenizer paired with a
    missing (or previously loaded) model.

    Returns:
        str: ``"Model loaded successfully!"`` on success, otherwise
        ``"Error loading model: <reason>"``. The same message is printed.
    """
    global model, tokenizer

    try:
        print("Loading SmallLM model...")
        model_name = "XsoraS/SmallLM"
        use_cuda = torch.cuda.is_available()

        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Fall back to the EOS token when the tokenizer defines no pad token.
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # Half precision and automatic placement only when a GPU exists.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            # NOTE(review): this lets code shipped in the model repository run
            # locally; keep it only while the repository is trusted.
            trust_remote_code=True,
        )
    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

    # Publish the pair together, only after both loads succeeded.
    tokenizer = loaded_tokenizer
    model = loaded_model

    print("Model loaded successfully!")
    return "Model loaded successfully!"
|
|
|
|
|
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """Generate a sampled continuation of ``prompt`` with the loaded model.

    Args:
        prompt: Text to continue.
        max_length: Passed to ``model.generate`` as ``max_length`` after being
            cast to ``int`` (UI sliders may deliver floats). Per the
            transformers documentation this limit counts the prompt tokens too.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        str: Only the newly generated text, stripped of surrounding
        whitespace. If the model or tokenizer is not loaded, the prompt
        encodes to zero tokens, or generation raises, a human-readable
        message is returned instead of raising.
    """
    if model is None or tokenizer is None:
        return "Please load the model first!"

    try:
        encoded = tokenizer(prompt, return_tensors="pt")
        device = model.device

        # Always follow the model's device rather than guessing from CUDA
        # availability.
        input_ids = encoded["input_ids"].to(device)
        prompt_token_count = input_ids.shape[-1]
        if prompt_token_count == 0:
            return "Please enter a prompt."

        generation_kwargs = {
            "max_length": int(max_length),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "num_return_sequences": 1,
        }
        # Pass the mask explicitly when the tokenizer provides one, so that
        # generate() does not have to infer it from the pad token id.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            generation_kwargs["attention_mask"] = attention_mask.to(device)

        with torch.no_grad():
            outputs = model.generate(input_ids, **generation_kwargs)

        # Drop the prompt by token count. Slicing the decoded string by
        # len(prompt) is wrong whenever decoding does not reproduce the
        # prompt character-for-character.
        new_tokens = outputs[0][prompt_token_count:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    except Exception as e:
        return f"Error generating text: {str(e)}"
|
|
|
|
|
def clear_text():
    """Return a pair of empty strings, one for the prompt box and one for the output box."""
    blank_prompt = ""
    blank_output = ""
    return blank_prompt, blank_output
|
|
|
|
|
|
|
|
# Gradio UI: a status/load row on top, then a prompt-and-controls column next
# to an output column. The emoji in the heading and button labels are restored
# from mis-encoded text.
with gr.Blocks(title="SmallLM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 SmallLM Inference Demo")
    gr.Markdown("Simple demo for XsoraS/SmallLM text generation")

    with gr.Row():
        with gr.Column(scale=1):
            load_btn = gr.Button("🚀 Load Model", variant="primary")
            status = gr.Textbox(
                label="Status",
                value="Click 'Load Model' to start",
                interactive=False,
            )

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Enter your prompt:",
                placeholder="Once upon a time...",
                lines=3,
            )

            # Sampling controls; their values are passed to generate_text().
            with gr.Row():
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                )

            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

        with gr.Column(scale=2):
            output = gr.Textbox(
                label="Generated Text:",
                lines=10,
                interactive=False,
            )

    # Event handlers
    load_btn.click(
        fn=load_model,
        outputs=status,
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p],
        outputs=output,
    )

    clear_btn.click(
        fn=clear_text,
        outputs=[prompt_input, output],
    )

    # Example prompts that fill the prompt box when clicked
    gr.Examples(
        examples=[
            ["The future of artificial intelligence is"],
            ["In a world where technology and nature coexist"],
            ["Write a short story about a robot who"],
            ["Explain quantum computing in simple terms:"],
        ],
        inputs=prompt_input,
    )
|
|
|
|
|
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()