import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

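# On ZeroGPU Spaces, this decorator allocates a GPU for each call, for up to 120 seconds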
@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
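    # Render the chat messages into the model's prompt format (text only, no tokenization yet)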
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    
    # The chat template already includes the BOS token, so don't add special tokens again
    inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    
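    # Sample a continuation; max_new_tokens and temperature come from the UI sliders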
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,  # Llama tokenizers have no dedicated pad token
    )
    
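    # Decode only the newly generated tokens, dropping the echoed prompt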
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

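# Custom CSS: page header, hover effect on the llama image, and overall layout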
css = """
body {
    background: linear-gradient(135deg, #f5f7fa, #c3cfe2);
    font-family: Arial, sans-serif;
}
#custom-header {
    text-align: center;
    background: rgba(255, 255, 255, 0.8);
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    position: relative;
    max-width: 800px;
    margin: 20px auto;
}
#custom-header h1 {
    color: #4A90E2;
    font-size: 2em;
    margin-bottom: 10px;
}
.llama-image {
    position: relative;
    transition: transform 0.3s;
    display: inline-block;
    margin-top: 20px;
}
.llama-image:hover {
    transform: scale(1.05);
}
.llama-image img {
    width: 200px;
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.llama-description {
    position: absolute;
    bottom: -30px;
    left: 50%;
    transform: translateX(-50%);
    background-color: #4A90E2;
    color: white;
    padding: 5px 10px;
    border-radius: 5px;
    opacity: 0;
    transition: opacity 0.3s;
    white-space: nowrap;
}
.llama-image:hover .llama-description {
    opacity: 1;
}
.gradio-container {
    max-width: 900px !important;
    margin: auto;
    padding-top: 1.5rem;
}
.container {
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    padding: 20px;
    margin-top: 20px;
}
"""

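# Assemble the Gradio UI: HTML header, input controls, and output box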
with gr.Blocks(css=css) as iface:
    gr.HTML("""
        <div id="custom-header">
            <h1>Llama-3.1-Storm-8B Text Generation</h1>
            <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
            <div class="llama-image">
                <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
                <div class="llama-description">Llama-3.1-Storm-8B Model</div>
            </div>
        </div>
    """)
    
    with gr.Column(elem_classes="container"):
        prompt = gr.Textbox(lines=5, label="Prompt")
        max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        submit_btn = gr.Button("Generate", variant="primary")
        output = gr.Textbox(lines=10, label="Generated Text")
    
    submit_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)

iface.launch()