FlameF0X committed on
Commit
1a9f67b
Β·
verified Β·
1 Parent(s): 31ca2bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
from pathlib import Path
import gradio as gr
import json
from huggingface_hub import hf_hub_download

# -------------------- DEVICE --------------------
# Prefer the GPU when one is visible; the model and all input tensors
# are moved to this device below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- MODEL CONFIG --------------------
MODEL_NAME = "FlameF0X/i3-80m"  # HF Hub repo used as fallback weight source
LOCAL_SAFETENSORS = Path("model.safetensors")  # preferred local checkpoint
LOCAL_BIN = Path("pytorch_model.bin")  # legacy pickle checkpoint fallback
VOCAB_JSON = Path("chunk_vocab_combined.json")  # tokenizer vocabulary file
# -------------------- LOAD VOCAB --------------------
# Read the tokenizer vocabulary. Open with an explicit UTF-8 encoding:
# without it, open() uses the platform default (e.g. cp1252 on Windows),
# which can corrupt or reject a UTF-8 JSON file.
with open(VOCAB_JSON, 'r', encoding="utf-8") as f:
    vocab_data = json.load(f)
# The JSON is expected to carry a top-level "vocab_size" field;
# a KeyError here means the vocabulary file is malformed.
VOCAB_SIZE = vocab_data["vocab_size"]
# -------------------- IMPORT YOUR MODEL CLASS --------------------
from app_classes import i3Model, ChunkTokenizer

# Chunk-level tokenizer; load() restores the vocabulary stored in VOCAB_JSON.
tokenizer = ChunkTokenizer()
tokenizer.load(VOCAB_JSON)

# Instantiate the i3 architecture. These hyperparameters must match the
# ones the checkpoint was trained with, otherwise load_state_dict() below
# will fail with shape mismatches.
model = i3Model(
    vocab_size=VOCAB_SIZE,
    d_model=512,      # hidden/embedding width
    n_heads=16,       # attention heads
    max_seq_len=256,  # maximum context length
    d_state=32        # presumably the SSM state size (Mamba-style) — confirm in app_classes
).to(DEVICE)
# -------------------- LOAD WEIGHTS --------------------
# Weight resolution order: local safetensors -> local .bin -> download
# from the Hugging Face Hub. Any failure is re-raised as RuntimeError.
try:
    if LOCAL_SAFETENSORS.exists():
        from safetensors.torch import load_file

        # safetensors is the safe format: no pickle execution involved.
        state_dict = load_file(LOCAL_SAFETENSORS)
        model.load_state_dict(state_dict)
        print("βœ… Loaded weights from local safetensors")
    elif LOCAL_BIN.exists():
        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # acceptable only because this is a local, trusted checkpoint.
        state_dict = torch.load(LOCAL_BIN, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("βœ… Loaded weights from local .bin")
    else:
        print("⚑ Downloading model from HuggingFace...")
        bin_file = hf_hub_download(repo_id=MODEL_NAME, filename="pytorch_model.bin")
        # NOTE(review): unpickling a *downloaded* file with
        # weights_only=False is a code-execution risk — prefer publishing
        # and loading model.safetensors from the repo instead.
        state_dict = torch.load(bin_file, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("βœ… Loaded weights from HuggingFace")
except Exception as e:
    # Chain the original exception ("from e") so the underlying traceback
    # (file-not-found, shape mismatch, ...) is preserved for debugging.
    raise RuntimeError(f"Failed to load model weights: {e}") from e

# Inference only: disable dropout / switch norm layers to eval behavior.
model.eval()
# -------------------- GENERATION FUNCTION --------------------
def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=40):
    """Stream generated text for *prompt*, yielding the full decoded
    sequence after each new token.

    Errors are surfaced to the UI as yielded strings rather than raised,
    so the Gradio output box always shows something useful.
    """
    if not prompt.strip():
        # Nothing to condition on — prompt the user instead of generating.
        yield "⚠️ Please enter a prompt to generate text."
        return

    try:
        encoded = tokenizer.encode(prompt)
        token_ids = torch.tensor([encoded], dtype=torch.long).to(DEVICE)

        # The model's streaming generator yields the growing sequence
        # one new token at a time.
        stream = model.generate_stream(
            token_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        for seq in stream:
            # Move to CPU before decoding — tokens may live on the GPU.
            yield tokenizer.decode(seq[0].cpu())

    except Exception as e:
        yield f"❌ Generation error: {str(e)}"
# -------------------- GRADIO UI --------------------
# Page-level CSS injected into the Blocks app below.
custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
.param-card {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 1.5rem;
    border-radius: 12px;
    margin-bottom: 1rem;
}
"""

# Two-column layout: prompt + parameters on the left, streamed output on
# the right; examples, a developer-info accordion and a footer below.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    with gr.Row():
        gr.Markdown(
            """
            # πŸš€ i3-80M Text Generation
            ### Powered by Mamba-based Architecture
            Generate creative text using the i3-80M language model with customizable parameters.
            """,
            elem_classes="main-header"
        )

    # Main Generation Area
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="✍️ Enter Your Prompt",
                placeholder="Once upon a time in a distant galaxy...",
                lines=4,
                max_lines=8
            )

            # Sampling knobs — passed straight through to generate_text().
            with gr.Accordion("βš™οΈ Generation Parameters", open=True):
                with gr.Row():
                    max_tokens_input = gr.Slider(
                        10, 500,
                        value=100,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                    temp_input = gr.Slider(
                        0.1, 2.0,
                        value=0.8,
                        step=0.05,
                        label="Temperature",
                        info="Higher = more creative, Lower = more focused"
                    )

                topk_input = gr.Slider(
                    1, 100,
                    value=40,
                    step=1,
                    label="Top-k Sampling",
                    info="Number of top tokens to consider"
                )

            with gr.Row():
                generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg")
                clear_btn = gr.ClearButton(components=[prompt_input], value="πŸ—‘οΈ Clear", size="lg")

        with gr.Column(scale=2):
            # Streaming target: generate_text yields into this box.
            output_text = gr.Textbox(
                label="πŸ“ Generated Output",
                lines=12,
                max_lines=20,
                show_copy_button=True
            )

    # Examples Section — each row fills [prompt, max_tokens, temperature, top_k].
    with gr.Row():
        gr.Examples(
            examples=[
                ["The future of artificial intelligence is", 150, 0.7, 50],
                ["In a world where technology and nature coexist", 200, 0.9, 40],
                ["The scientist discovered something remarkable", 120, 0.8, 45],
            ],
            inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
            label="πŸ’‘ Try These Examples"
        )

    # Developer Panel — static model/config facts rendered at build time.
    with gr.Accordion("πŸ”§ Developer Info", open=False):
        # Count only trainable parameters.
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **Model Architecture:**
                - **Model:** i3-80M
                - **Device:** {DEVICE}
                - **Vocab Size:** {VOCAB_SIZE:,}
                - **Parameters:** {total_params:,} ({total_params/1e6:.2f}M)
                """)

            with gr.Column():
                gr.Markdown(f"""
                **Configuration:**
                - **d_model:** 512
                - **n_heads:** 16
                - **max_seq_len:** 256
                - **d_state:** 32
                """)

    # Footer
    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666;">
        <p>Built with ❀️ using Gradio | Model: FlameF0X/i3-80m</p>
        </div>
        """,
    )

    # Connect UI — generate_text is a generator, so the output box
    # updates incrementally as tokens stream in.
    generate_btn.click(
        generate_text,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text]
    )
# -------------------- RUN --------------------
if __name__ == "__main__":
    # queue() is generally required for streaming to work correctly
    # share=False: serve locally only, no public gradio.live tunnel.
    demo.queue().launch(share=False)