|
|
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository that both the tokenizer and the weights are
# loaded from (passed to `from_pretrained` in `load_model`).
MODEL_ID = "OpceanAI/Yuuki-best"

# Lazy-loading state shared by `load_model` and `generate_code`.
# `MODEL_LOADED` only becomes True after `load_model` has assigned both
# `model` and `tokenizer`.
MODEL_LOADED = False
model = None
tokenizer = None
|
|
|
|
|
|
|
|
def load_model() -> bool:
    """Load the Yuuki tokenizer and model into the module globals on first use.

    Once ``MODEL_LOADED`` is set, later calls return immediately.

    Returns:
        True if the model is available (already loaded, or loaded by this
        call). False if loading raised; the error is printed and the module
        globals are left exactly as they were.
    """
    global model, tokenizer, MODEL_LOADED

    if MODEL_LOADED:
        return True

    try:
        print("Loading Yuuki model...")

        # Build into locals first so a failure part-way through cannot leave
        # a tokenizer published without its matching model.
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            # NOTE(review): trust_remote_code lets the model repository supply
            # Python code that runs locally; keep it only while MODEL_ID
            # points at a repository you trust.
            trust_remote_code=True,
        )

        # Guarantee a pad token exists by falling back to the EOS token.
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    except Exception as e:
        # Top-level boundary for the UI: report the failure and let the
        # caller show a friendly message instead of crashing the app.
        print(f"Error loading model: {e}")
        return False

    tokenizer = loaded_tokenizer
    model = loaded_model
    MODEL_LOADED = True
    print("Model loaded successfully!")
    return True
|
|
|
|
|
|
|
|
def generate_code(
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate a sampled continuation of ``prompt`` with the Yuuki model.

    The model is loaded lazily via ``load_model()`` on the first non-empty
    request.

    Args:
        prompt: Code prompt to continue. It is tokenized with truncation at
            512 tokens.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to
            ``model.generate``.

    Returns:
        The first returned sequence decoded with special tokens skipped, or a
        human-readable message when the prompt is blank, the model cannot be
        loaded, or generation raises.
    """
    # Validate the cheap condition first so a blank submission never triggers
    # the expensive model load.
    if not prompt or not prompt.strip():
        return "Please enter a code prompt."

    if not MODEL_LOADED and not load_model():
        return "Error: Model failed to load. Please try refreshing the page."

    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # UI sliders may deliver whole numbers as floats (assumption,
                # depends on the Gradio version); these two must be integers.
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: show the failure in the output box rather than raising.
        return f"Generation error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CUSTOM_CSS = """ |
|
|
.gradio-container { |
|
|
max-width: 100% !important; |
|
|
padding: 0 !important; |
|
|
margin: 0 !important; |
|
|
background: #0a0a0a !important; |
|
|
min-height: 100vh; |
|
|
} |
|
|
|
|
|
.main { |
|
|
background: #0a0a0a !important; |
|
|
} |
|
|
|
|
|
footer { |
|
|
display: none !important; |
|
|
} |
|
|
|
|
|
#header { |
|
|
display: flex; |
|
|
align-items: center; |
|
|
justify-content: space-between; |
|
|
padding: 16px 24px; |
|
|
border-bottom: 1px solid #1f1f1f; |
|
|
background: #0a0a0a; |
|
|
} |
|
|
|
|
|
#logo { |
|
|
font-size: 1.25rem; |
|
|
font-weight: 600; |
|
|
color: #fafafa; |
|
|
} |
|
|
|
|
|
#version-tag { |
|
|
color: #666; |
|
|
font-weight: 400; |
|
|
font-size: 0.875rem; |
|
|
margin-left: 8px; |
|
|
} |
|
|
|
|
|
#chat-container { |
|
|
max-width: 800px; |
|
|
margin: 0 auto; |
|
|
padding: 40px 24px; |
|
|
} |
|
|
|
|
|
#welcome-box { |
|
|
text-align: center; |
|
|
margin-bottom: 32px; |
|
|
} |
|
|
|
|
|
#welcome-title { |
|
|
font-size: 2rem; |
|
|
font-weight: 600; |
|
|
color: #fafafa; |
|
|
margin-bottom: 8px; |
|
|
} |
|
|
|
|
|
#welcome-subtitle { |
|
|
font-size: 1rem; |
|
|
color: #666; |
|
|
margin-bottom: 16px; |
|
|
} |
|
|
|
|
|
#disclaimer { |
|
|
background: #18181b; |
|
|
border: 1px solid #27272a; |
|
|
border-radius: 8px; |
|
|
padding: 12px 16px; |
|
|
color: #a1a1a1; |
|
|
font-size: 0.8rem; |
|
|
text-align: left; |
|
|
display: inline-block; |
|
|
max-width: 600px; |
|
|
} |
|
|
|
|
|
#output-box textarea { |
|
|
background: #141414 !important; |
|
|
color: #e5e5e5 !important; |
|
|
font-family: monospace !important; |
|
|
font-size: 0.875rem !important; |
|
|
border: 1px solid #262626 !important; |
|
|
border-radius: 12px !important; |
|
|
} |
|
|
|
|
|
#input-box textarea { |
|
|
background: #141414 !important; |
|
|
color: #fafafa !important; |
|
|
border: 1px solid #262626 !important; |
|
|
border-radius: 12px !important; |
|
|
} |
|
|
|
|
|
#input-box textarea::placeholder { |
|
|
color: #525252 !important; |
|
|
} |
|
|
|
|
|
#generate-btn { |
|
|
background: #fafafa !important; |
|
|
color: #0a0a0a !important; |
|
|
border: none !important; |
|
|
border-radius: 8px !important; |
|
|
font-weight: 600 !important; |
|
|
} |
|
|
|
|
|
#generate-btn:hover { |
|
|
background: #e5e5e5 !important; |
|
|
} |
|
|
|
|
|
#examples-label { |
|
|
color: #525252; |
|
|
font-size: 0.75rem; |
|
|
text-transform: uppercase; |
|
|
letter-spacing: 0.05em; |
|
|
margin-bottom: 12px; |
|
|
margin-top: 16px; |
|
|
} |
|
|
|
|
|
.example-btn { |
|
|
background: #141414 !important; |
|
|
border: 1px solid #262626 !important; |
|
|
color: #a1a1a1 !important; |
|
|
font-family: monospace !important; |
|
|
border-radius: 8px !important; |
|
|
} |
|
|
|
|
|
.example-btn:hover { |
|
|
background: #1f1f1f !important; |
|
|
border-color: #404040 !important; |
|
|
color: #fafafa !important; |
|
|
} |
|
|
|
|
|
.panel-section { |
|
|
background: #141414; |
|
|
border: 1px solid #262626; |
|
|
border-radius: 12px; |
|
|
padding: 24px; |
|
|
margin-bottom: 16px; |
|
|
} |
|
|
|
|
|
.panel-title { |
|
|
font-size: 0.875rem; |
|
|
font-weight: 600; |
|
|
color: #fafafa; |
|
|
margin-bottom: 16px; |
|
|
} |
|
|
|
|
|
.info-row { |
|
|
display: flex; |
|
|
justify-content: space-between; |
|
|
padding: 12px 0; |
|
|
border-bottom: 1px solid #1f1f1f; |
|
|
} |
|
|
|
|
|
.info-row:last-child { |
|
|
border-bottom: none; |
|
|
} |
|
|
|
|
|
.info-label { |
|
|
color: #666; |
|
|
font-size: 0.875rem; |
|
|
} |
|
|
|
|
|
.info-value { |
|
|
color: #fafafa; |
|
|
font-size: 0.875rem; |
|
|
font-weight: 500; |
|
|
} |
|
|
|
|
|
.score-grid { |
|
|
display: flex; |
|
|
gap: 8px; |
|
|
flex-wrap: wrap; |
|
|
} |
|
|
|
|
|
.score-badge { |
|
|
padding: 6px 12px; |
|
|
border-radius: 6px; |
|
|
font-size: 0.75rem; |
|
|
font-weight: 600; |
|
|
} |
|
|
|
|
|
.score-good { |
|
|
background: rgba(34, 197, 94, 0.15); |
|
|
color: #22c55e; |
|
|
border: 1px solid rgba(34, 197, 94, 0.3); |
|
|
} |
|
|
|
|
|
.score-medium { |
|
|
background: rgba(234, 179, 8, 0.15); |
|
|
color: #eab308; |
|
|
border: 1px solid rgba(234, 179, 8, 0.3); |
|
|
} |
|
|
|
|
|
.score-weak { |
|
|
background: rgba(239, 68, 68, 0.15); |
|
|
color: #ef4444; |
|
|
border: 1px solid rgba(239, 68, 68, 0.3); |
|
|
} |
|
|
|
|
|
.comparison-table { |
|
|
width: 100%; |
|
|
border-collapse: collapse; |
|
|
font-size: 0.875rem; |
|
|
} |
|
|
|
|
|
.comparison-table th, |
|
|
.comparison-table td { |
|
|
padding: 12px; |
|
|
text-align: left; |
|
|
border-bottom: 1px solid #1f1f1f; |
|
|
} |
|
|
|
|
|
.comparison-table th { |
|
|
color: #666; |
|
|
font-weight: 500; |
|
|
font-size: 0.75rem; |
|
|
text-transform: uppercase; |
|
|
} |
|
|
|
|
|
.comparison-table td { |
|
|
color: #a1a1a1; |
|
|
} |
|
|
|
|
|
.comparison-table strong { |
|
|
color: #22c55e; |
|
|
} |
|
|
|
|
|
.links-grid { |
|
|
display: flex; |
|
|
gap: 16px; |
|
|
flex-wrap: wrap; |
|
|
} |
|
|
|
|
|
.link-item { |
|
|
color: #a1a1a1; |
|
|
text-decoration: none; |
|
|
font-size: 0.875rem; |
|
|
} |
|
|
|
|
|
.link-item:hover { |
|
|
color: #fafafa; |
|
|
} |
|
|
|
|
|
.gr-tab-nav button { |
|
|
background: transparent !important; |
|
|
border: none !important; |
|
|
color: #666 !important; |
|
|
} |
|
|
|
|
|
.gr-tab-nav button.selected { |
|
|
color: #fafafa !important; |
|
|
border-bottom: 2px solid #fafafa !important; |
|
|
} |
|
|
|
|
|
.gr-prose { |
|
|
color: #a1a1a1 !important; |
|
|
} |
|
|
|
|
|
.gr-prose strong { |
|
|
color: #fafafa !important; |
|
|
} |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the three-tab UI (Chat / Settings / Info) and wire the generation and
# example-prompt events. Everything below runs at import time and binds the
# resulting app to `demo`.
with gr.Blocks(
    css=CUSTOM_CSS,
    title="Yuuki",
    theme=gr.themes.Base(
        primary_hue="neutral",
        secondary_hue="neutral",
        neutral_hue="neutral",
    ).set(
        body_background_fill="#0a0a0a",
        body_background_fill_dark="#0a0a0a",
        block_background_fill="#141414",
        block_background_fill_dark="#141414",
        block_border_color="#262626",
        block_border_color_dark="#262626",
        body_text_color="#a1a1a1",
        body_text_color_dark="#a1a1a1",
        input_background_fill="#141414",
        input_background_fill_dark="#141414",
    ),
) as demo:

    gr.HTML('<div id="header"><div id="logo">Yuuki <span id="version-tag">v0.1-preview</span></div></div>')

    with gr.Tabs():

        with gr.Tab("Chat"):
            # A <div> opened in one gr.HTML and closed in another cannot wrap
            # the components created in between, because each gr.HTML renders
            # its own fragment. A layout column carrying the id makes the
            # `#chat-container` rules apply to the whole chat area.
            with gr.Column(elem_id="chat-container"):
                gr.HTML('<div id="welcome-box"><div id="welcome-title">Yuuki</div><div id="welcome-subtitle">Mobile-trained code generation model</div><div id="disclaimer"><strong>Experimental model.</strong> Best at Agda (55/100). Limited C, Assembly. Weak Python. Trained on smartphone CPU.</div></div>')

                output = gr.Textbox(
                    label="Output",
                    lines=10,
                    show_copy_button=True,
                    elem_id="output-box",
                    placeholder="Generated code will appear here...",
                )

                with gr.Row():
                    prompt_input = gr.Textbox(
                        label="",
                        placeholder="Enter code prompt... (e.g., module Main where)",
                        lines=2,
                        elem_id="input-box",
                        show_label=False,
                        scale=4,
                    )
                    generate_btn = gr.Button("Generate", variant="primary", elem_id="generate-btn", scale=1)

                gr.HTML('<div id="examples-label">Try these</div>')
                with gr.Row():
                    ex1 = gr.Button("module Main where", elem_classes=["example-btn"], size="sm")
                    ex2 = gr.Button("open import Data.Nat", elem_classes=["example-btn"], size="sm")
                    ex3 = gr.Button("int main() {", elem_classes=["example-btn"], size="sm")
                    ex4 = gr.Button("def hello():", elem_classes=["example-btn"], size="sm")

        with gr.Tab("Settings"):
            with gr.Column(elem_classes=["panel-section"]):
                gr.HTML('<div class="panel-title">Generation Parameters</div>')

                # Slider defaults mirror the keyword defaults of generate_code.
                with gr.Row():
                    with gr.Column():
                        max_new_tokens = gr.Slider(minimum=20, maximum=256, value=100, step=10, label="Max Tokens")
                        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
                        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
                    with gr.Column():
                        top_k = gr.Slider(minimum=1, maximum=100, value=50, step=5, label="Top K")
                        repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty")

        with gr.Tab("Info"):
            with gr.Column(elem_classes=["panel-section"]):
                gr.HTML('<div class="panel-title">Model Information</div>')
                # NOTE(review): each row is its own gr.HTML component, so the
                # `.info-row:last-child` rule probably matches every row, not
                # just the final one. Confirm in the browser.
                gr.HTML('<div class="info-row"><span class="info-label">Model</span><span class="info-value">Yuuki-best (checkpoint-2000)</span></div>')
                gr.HTML('<div class="info-row"><span class="info-label">Size</span><span class="info-value">988 MB</span></div>')
                gr.HTML('<div class="info-row"><span class="info-label">Training Progress</span><span class="info-value">2,000 / 37,500 steps (5.3%)</span></div>')
                gr.HTML('<div class="info-row"><span class="info-label">Hardware</span><span class="info-value">Snapdragon 685 (CPU only)</span></div>')
                gr.HTML('<div class="info-row"><span class="info-label">Speed</span><span class="info-value">~86 sec/step</span></div>')
                gr.HTML('<div class="info-row"><span class="info-label">Cost</span><span class="info-value">$0.00</span></div>')

            with gr.Column(elem_classes=["panel-section"]):
                gr.HTML('<div class="panel-title">Language Performance</div>')
                gr.HTML('<div class="score-grid"><span class="score-badge score-good">Agda: 55/100</span><span class="score-badge score-medium">C: 20/100</span><span class="score-badge score-medium">Assembly: 15/100</span><span class="score-badge score-weak">Python: 8/100</span></div>')
                gr.HTML('<p style="color: #666; font-size: 0.8rem; margin-top: 16px;">Average quality: 24.6/100 (+146% from checkpoint 1400)</p>')

            with gr.Column(elem_classes=["panel-section"]):
                gr.HTML('<div class="panel-title">Checkpoint Comparison</div>')
                gr.HTML('<table class="comparison-table"><thead><tr><th>Metric</th><th>CP-1400</th><th>CP-2000</th></tr></thead><tbody><tr><td>Progress</td><td>3.7%</td><td><strong>5.3%</strong></td></tr><tr><td>Agda</td><td>20/100</td><td><strong>55/100</strong></td></tr><tr><td>C</td><td>8/100</td><td><strong>20/100</strong></td></tr><tr><td>Assembly</td><td>2/100</td><td><strong>15/100</strong></td></tr><tr><td>Average</td><td>~10/100</td><td><strong>24.6/100</strong></td></tr></tbody></table>')

            with gr.Column(elem_classes=["panel-section"]):
                gr.HTML('<div class="panel-title">About</div>')
                gr.Markdown("This is the **best model available at this moment**. The full **v0.1** release is coming soon. Once published, plans for **v0.2** will begin.\n\nYuuki is being trained **entirely on a smartphone CPU** by a **single person**. A research paper exploring mobile LLM training will be published soon.\n\n**Why this matters:**\n- Students without GPU access can experiment with ML\n- Democratizes ML research globally\n- Explores edge ML training possibilities")

            with gr.Column(elem_classes=["panel-section"]):
                gr.HTML('<div class="panel-title">Links</div>')
                gr.HTML('<div class="links-grid"><a href="https://huggingface.co/OpceanAI/Yuuki-best" target="_blank" class="link-item">Model Card</a><a href="https://huggingface.co/OpceanAI/Yuuki" target="_blank" class="link-item">Original Yuuki</a><a href="https://github.com/YuuKi-OS/yuuki-training" target="_blank" class="link-item">Training Code</a></div>')
                gr.HTML('<p style="color: #525252; font-size: 0.75rem; margin-top: 24px;">Licensed under Apache 2.0</p>')

    # Both triggers (button click and Enter in the prompt box) run the same
    # generation with the same inputs, in generate_code's positional order.
    generation_inputs = [prompt_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty]

    generate_btn.click(
        fn=generate_code,
        inputs=generation_inputs,
        outputs=output,
    )

    prompt_input.submit(
        fn=generate_code,
        inputs=generation_inputs,
        outputs=output,
    )

    # Example buttons copy their text into the prompt box; they do not start
    # generation themselves.
    ex1.click(lambda: "module Main where", outputs=prompt_input)
    ex2.click(lambda: "open import Data.Nat", outputs=prompt_input)
    ex3.click(lambda: "int main() {", outputs=prompt_input)
    ex4.click(lambda: "def hello():", outputs=prompt_input)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve locally only (no public share link), surface handler errors in
    # the UI, and hide the auto-generated API docs link.
    demo.launch(share=False, show_error=True, show_api=False)
|
|
|
|
|
|