FlameF0X commited on
Commit
03760bf
·
verified ·
1 Parent(s): b5872a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -39
app.py CHANGED
@@ -1,43 +1,83 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM # or your model class
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # load tokenizer & model
6
- model_name = "FlameF0X/i3-80m" # replace with correct HF model path
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  model.eval()
10
- if torch.cuda.is_available():
11
- model = model.cuda()
12
-
13
- def generate(prompt: str, max_new_tokens: int = 100, temperature: float = 1.0, top_k: int = None):
14
- inputs = tokenizer(prompt, return_tensors="pt")
15
- input_ids = inputs["input_ids"]
16
- if torch.cuda.is_available():
17
- input_ids = input_ids.cuda()
18
  with torch.no_grad():
19
- output_ids = model.generate(
20
- input_ids,
21
- max_new_tokens=max_new_tokens,
22
- temperature=temperature,
23
- top_k=top_k,
24
- do_sample=True
25
- )
26
- output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
27
- return output
28
-
29
- # Gradio interface
30
- iface = gr.Interface(
31
- fn=generate,
32
- inputs=[
33
- gr.Textbox(label="Prompt", lines=2, placeholder="Enter prompt here..."),
34
- gr.Slider(label="Max new tokens", minimum=1, maximum=500, step=1, value=100),
35
- gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0),
36
- gr.Slider(label="Top-k (0 = disabled)", minimum=0, maximum=200, step=1, value=40)
37
- ],
38
- outputs=gr.Textbox(label="Generated Text"),
39
- title="i3-80m Generation Demo",
40
- description="Interact with the i3 hybrid-architecture model."
41
- )
42
-
43
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import json
4
+ from safetensors.torch import load_file as safe_load
5
+ from huggingface_hub import hf_hub_download
6
+ from app_classes import i3Model, ChunkTokenizer # Make sure your classes file is importable
7
+
8
+ # ------------------------------
9
+ # Hugging Face Repo & Files
10
+ # ------------------------------
11
+ REPO_ID = "FlameF0X/i3-80m" # Replace with your HF repo
12
+
13
+ print("Downloading model files from Hugging Face...")
14
+ model_file = hf_hub_download(REPO_ID, "model.safetensors")
15
+ vocab_file = hf_hub_download(REPO_ID, "chunk_vocab_combined.json")
16
+ config_file = hf_hub_download(REPO_ID, "config.json")
17
+
18
+ # ------------------------------
19
+ # Load Config
20
+ # ------------------------------
21
+ with open(config_file, "r") as f:
22
+ config = json.load(f)
23
 
24
+ # ------------------------------
25
+ # Load Tokenizer
26
+ # ------------------------------
27
+ tokenizer = ChunkTokenizer()
28
+ tokenizer.load(vocab_file)
29
+
30
+ # ------------------------------
31
+ # Initialize Model
32
+ # ------------------------------
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ model = i3Model(vocab_size=tokenizer.vocab_size,
35
+ d_model=config.get("d_model", 512),
36
+ n_heads=config.get("n_heads", 16),
37
+ max_seq_len=config.get("max_seq_len", 512),
38
+ d_state=config.get("d_state", 32)).to(device)
39
+
40
+ # Load weights
41
+ state_dict = safe_load(model_file, device=device)
42
+ model.load_state_dict(state_dict)
43
  model.eval()
44
+
45
+ # ------------------------------
46
+ # Generation Function
47
+ # ------------------------------
48
+ def generate_text(prompt, max_tokens=100, temperature=1.0, top_k=40):
49
+ idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(device)
 
 
50
  with torch.no_grad():
51
+ out_idx = model.generate(idx, max_new_tokens=int(max_tokens),
52
+ temperature=float(temperature),
53
+ top_k=int(top_k))
54
+ return tokenizer.decode(out_idx[0].cpu())
55
+
56
+ # ------------------------------
57
+ # Gradio UI
58
+ # ------------------------------
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("## i3 Model Text Generator")
61
+
62
+ with gr.Row():
63
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Type your text here...", lines=3)
64
+ generate_btn = gr.Button("Generate")
65
+
66
+ output_box = gr.Textbox(label="Generated Text", lines=10)
67
+
68
+ with gr.Accordion("Dev Panel", open=False):
69
+ max_tokens_input = gr.Slider(10, 500, value=100, label="Max Tokens")
70
+ temperature_input = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
71
+ top_k_input = gr.Slider(1, tokenizer.vocab_size, value=40, step=1, label="Top-k Sampling")
72
+
73
+ # Connect button
74
+ generate_btn.click(
75
+ generate_text,
76
+ inputs=[prompt_input, max_tokens_input, temperature_input, top_k_input],
77
+ outputs=[output_box]
78
+ )
79
+
80
+ # ------------------------------
81
+ # Launch App
82
+ # ------------------------------
83
+ demo.launch(share=True)