# app.py - SmallLM Gradio Demo
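#
# Minimal UI for text generation with the XsoraS/SmallLM checkpoint from the
# Hugging Face Hub. To run locally (assuming these are the dependencies the
# Space relies on):
#   pip install gradio torch transformers accelerate
#   python app.py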
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the SmallLM model and tokenizer"""
    global model, tokenizer
    
    try:
        print("Loading SmallLM model...")
        model_name = "XsoraS/SmallLM"
        
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model
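        # float16 halves GPU memory use; device_map="auto" (which requires
        # the accelerate package) places the weights on the GPU when one is
        # available. Note: trust_remote_code executes code shipped with the
        # checkpoint, so only enable it for repositories you trust.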
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        
        print("Model loaded successfully!")
        return "Model loaded successfully!"
        
    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

def generate_text(prompt, max_new_tokens=100, temperature=0.7, top_p=0.9):
    """Generate text using the loaded model"""
    global model, tokenizer
    
    if model is None or tokenizer is None:
        return "Please load the model first!"
    
    try:
        # Tokenize the prompt; using the tokenizer's __call__ also returns an
        # attention mask, which avoids a warning from generate()
        encoded = tokenizer(prompt, return_tensors="pt")
        
        # Move inputs to the same device as the model (works on CPU and GPU)
        encoded = {k: v.to(model.device) for k, v in encoded.items()}
        
        # Generate; max_new_tokens counts only newly generated tokens, so a
        # long prompt cannot eat into the generation budget the way
        # max_length (prompt + generation) would
        with torch.no_grad():
            outputs = model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )
        
        # Decode only the newly generated tokens; slicing the decoded string
        # by len(prompt) is fragile because decoding does not always
        # round-trip the prompt exactly
        prompt_len = encoded["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        
    except Exception as e:
        return f"Error generating text: {str(e)}"

def clear_text():
    """Clear the input and output"""
    return "", ""

# Create Gradio interface
with gr.Blocks(title="SmallLM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 SmallLM Inference Demo")
    gr.Markdown("Simple demo for XsoraS/SmallLM text generation")
    
    with gr.Row():
        with gr.Column(scale=1):
            load_btn = gr.Button("🔄 Load Model", variant="primary")
            status = gr.Textbox(
                label="Status", 
                value="Click 'Load Model' to start",
                interactive=False
            )
    
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Enter your prompt:",
                placeholder="Once upon a time...",
                lines=3
            )
            
            with gr.Row():
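                # Sampling knobs: temperature rescales the logits (higher =
                # more random); top-p samples from the smallest set of tokens
                # whose cumulative probability reaches p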
                max_new_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )
            
            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")
        
        with gr.Column(scale=2):
            output = gr.Textbox(
                label="Generated Text:",
                lines=10,
                interactive=False
            )
    
    # Event handlers
    load_btn.click(
        fn=load_model,
        outputs=status
    )
    
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_new_tokens, temperature, top_p],
        outputs=output
    )
    
    clear_btn.click(
        fn=clear_text,
        outputs=[prompt_input, output]
    )
    
    # Examples
    gr.Examples(
        examples=[
            ["The future of artificial intelligence is"],
            ["In a world where technology and nature coexist"],
            ["Write a short story about a robot who"],
            ["Explain quantum computing in simple terms:"],
        ],
        inputs=prompt_input
    )

if __name__ == "__main__":
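    # share=True would additionally create a temporary public URL (useful
    # when running in a notebook), e.g. demo.launch(share=True)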
    demo.launch()