# smallm / app.py
# HAMMALE's picture
# Update app.py
# 8413b35 verified
# app.py - SmallLM Gradio Demo
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
# Shared inference state: both names stay None until load_model() fills them in,
# and generate_text() refuses to run while either one is still None.
tokenizer = None
model = None
def load_model():
    """Load the SmallLM tokenizer and model into the module-level globals.

    The tokenizer and model are first loaded into local variables and are
    only published to the ``tokenizer`` / ``model`` globals once BOTH have
    loaded, so a failure half-way through never leaves a new tokenizer
    paired with a missing (or previously loaded) model.

    Returns:
        str: ``"Model loaded successfully!"`` on success, otherwise
        ``"Error loading model: <details>"``. Errors are reported through
        the return value (and printed) rather than raised, because the
        result is shown in the UI status box.
    """
    global model, tokenizer

    try:
        print("Loading SmallLM model...")
        model_name = "XsoraS/SmallLM"

        # Load tokenizer
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Fall back to the EOS token when the tokenizer defines no pad token
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        # Half precision and automatic device placement only when CUDA exists;
        # otherwise full precision with default placement.
        use_cuda = torch.cuda.is_available()

        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time - confirm the repo is trusted.
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            trust_remote_code=True,
        )
    except Exception as e:
        # Top-level UI boundary: surface the failure as a status message.
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

    # Commit both objects together now that nothing can fail any more.
    model, tokenizer = loaded_model, loaded_tokenizer
    print("Model loaded successfully!")
    return "Model loaded successfully!"
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """Sample a continuation of ``prompt`` from the loaded model.

    Args:
        prompt: Text to continue.
        max_length: Passed to ``model.generate`` as ``max_length`` after
            conversion to ``int`` (UI sliders may deliver floats).
            NOTE(review): in Hugging Face ``generate`` this limit counts
            the prompt tokens as well as the new ones - confirm that is
            the intended meaning of the "Max Length" slider.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        str: The newly generated text only (prompt removed, surrounding
        whitespace stripped); ``"Please load the model first!"`` when the
        model or tokenizer has not been loaded; or
        ``"Error generating text: <details>"`` if anything raises.
    """
    if model is None or tokenizer is None:
        return "Please load the model first!"

    try:
        # Tokenize through __call__ so an attention mask comes back too; the
        # pad token may be aliased to EOS, so the mask cannot be inferred.
        encoded = tokenizer(prompt, return_tensors="pt")

        # Always place inputs on the model's device (a no-op on CPU).
        device = model.device
        input_ids = encoded["input_ids"].to(device)

        generate_kwargs = {
            "max_length": int(max_length),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "num_return_sequences": 1,
        }
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask.to(device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids, **generate_kwargs)

        # Drop the prompt at TOKEN level. Slicing the decoded string by
        # len(prompt) is wrong whenever decoding does not reproduce the
        # prompt character-for-character (e.g. whitespace normalisation).
        prompt_token_count = input_ids.shape[-1]
        new_tokens = outputs[0][prompt_token_count:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        # Top-level UI boundary: show the failure in the output box.
        return f"Error generating text: {str(e)}"
def clear_text():
    """Return a pair of empty strings that reset the prompt and output boxes."""
    blank_prompt = blank_output = ""
    return blank_prompt, blank_output
# Create Gradio interface
#
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order because the original indentation was lost; confirm
# the row/column grouping against the intended page layout.
# The emoji in the headings and button labels were repaired from mis-decoded
# UTF-8 byte sequences.
with gr.Blocks(title="SmallLM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 SmallLM Inference Demo")
    gr.Markdown("Simple demo for XsoraS/SmallLM text generation")

    # Model loading controls: a button plus a read-only status box
    with gr.Row():
        with gr.Column(scale=1):
            load_btn = gr.Button("🔄 Load Model", variant="primary")
            status = gr.Textbox(
                label="Status",
                value="Click 'Load Model' to start",
                interactive=False
            )

    with gr.Row():
        # Left column: prompt, sampling parameters and action buttons
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Enter your prompt:",
                placeholder="Once upon a time...",
                lines=3
            )
            with gr.Row():
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )
            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

        # Right column: read-only box that receives the generated text
        with gr.Column(scale=2):
            output = gr.Textbox(
                label="Generated Text:",
                lines=10,
                interactive=False
            )

    # Event handlers
    # load_model() returns a status string that is written to the status box
    load_btn.click(
        fn=load_model,
        outputs=status
    )
    # Input order must match generate_text(prompt, max_length, temperature, top_p)
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p],
        outputs=output
    )
    # clear_text() returns two empty strings: one for the prompt, one for the output
    clear_btn.click(
        fn=clear_text,
        outputs=[prompt_input, output]
    )

    # Examples: each entry is a one-element row that fills the prompt box
    gr.Examples(
        examples=[
            ["The future of artificial intelligence is"],
            ["In a world where technology and nature coexist"],
            ["Write a short story about a robot who"],
            ["Explain quantum computing in simple terms:"],
        ],
        inputs=prompt_input
    )

# Start the Gradio app only when this file is run as a script
if __name__ == "__main__":
    demo.launch()