| import gradio as gr |
| import torch |
| from gradio.themes.utils import sizes |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| import utils |
| from constants import END_OF_TEXT |
| from settings import DEFAULT_PORT |
|
|
| |
# Single source for the checkpoint id used by both the tokenizer and the model.
_CHECKPOINT = "BEE-spoke-data/smol_llama-101M-GQA-python"

# use_fast=False explicitly requests the non-"fast" tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, use_fast=False)
# Padding setup: the pad id is first pointed at the EOS id, then the pad token
# text is set to END_OF_TEXT. The order of the two assignments is kept as-is.
# NOTE(review): presumably END_OF_TEXT is the EOS token string -- confirm.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT

# Load the causal LM (device placement delegated via device_map="auto") and wrap
# it with torch.compile in "reduce-overhead" mode.
model = torch.compile(
    AutoModelForCausalLM.from_pretrained(_CHECKPOINT, device_map="auto"),
    mode="reduce-overhead",
)
|
|
| |
|
|
# Stylesheet text, later passed to gr.Blocks(css=...).
_styles = utils.get_file_as_string("styles.css")

# README.md is loaded once and divided into five pieces using "---" as the
# marker; each piece is bound to its own module-level name, in order.
# NOTE(review): presumably get_sections splits on the marker and returns exactly
# `up_to` items -- confirm against utils.
readme_file_content = utils.get_file_as_string("README.md", path="./")
_readme_sections = utils.get_sections(readme_file_content, "---", up_to=5)
manifest, description, disclaimer, base_model_info, formats = _readme_sections
|
|
# Font preference order: IBM Plex Sans from Google Fonts first, then generic
# sans-serif fallbacks.
_font_stack = [
    gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Soft base theme with a yellow/orange/slate palette, small corner radius and
# large text.
theme = gr.themes.Soft(
    font=_font_stack,
    text_size=sizes.text_lg,
    radius_size=sizes.radius_sm,
    primary_hue="yellow",
    secondary_hue="orange",
    neutral_hue="slate",
)
|
|
|
|
def run_inference(prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate a completion for ``prompt`` with the module-level model.

    Args:
        prompt: Text to continue; it is tokenized and moved to the model's device.
        temperature: Sampling temperature. A value of 0 (or below) switches
            sampling off and falls back to deterministic beam search, because
            sampling requires a strictly positive temperature.
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int`` and
            raised to the minimum generation length (8) when smaller.
        top_p: Nucleus-sampling threshold; only used when sampling is active.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The decoded text of the first returned sequence, with special tokens
        stripped.
    """
    min_new_tokens = 8
    # UI sliders may deliver floats, and the slider range includes 0; the token
    # budget must be an integer that does not undercut `min_new_tokens`.
    max_new_tokens = max(int(max_new_tokens), min_new_tokens)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "renormalize_logits": True,
        "no_repeat_ngram_size": 6,
        "repetition_penalty": repetition_penalty,
        "num_beams": 3,
        "early_stopping": True,
    }
    if temperature > 0:
        # Regular path: beam-sample with the user's temperature / top-p.
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    else:
        # The temperature slider bottoms out at 0.0, which is not a valid
        # sampling temperature; decode deterministically instead of failing.
        generation_kwargs["do_sample"] = False

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, **generation_kwargs)
    # `outputs` is a batch; there is a single prompt, so take the first entry.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
|
|
|
| |
def gradio_interface(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
):
    """Callback wired to the UI: hand the widget values to ``run_inference``.

    Args:
        prompt: Code typed into the prompt box.
        temperature: Value of the temperature control.
        max_new_tokens: Value of the max-new-tokens control.
        top_p: Value of the top-p control.
        repetition_penalty: Value of the repetition-penalty control.

    Returns:
        Whatever ``run_inference`` returns for these arguments.
    """
    completion = run_inference(
        prompt=prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return completion
|
|
|
|
| import random |
|
|
# Generation settings shared by every example row, in the order the UI inputs
# expect them: temperature, max new tokens, top-p, repetition penalty.
_EXAMPLE_SETTINGS = [0.2, 192, 0.9, 1.2]

# Prompts offered as clickable examples. They are Python snippets, so the
# indentation inside the strings is significant: nested lines use 4 spaces per
# level so that each prompt is a valid prefix of a Python program.
_EXAMPLE_PROMPTS = [
    "def add_numbers(a, b):\n    return",
    (
        "class Car:\n"
        "    def __init__(self, make, model):\n"
        "        self.make = make\n"
        "        self.model = model\n"
        "\n"
        "    def display_car(self):"
    ),
    (
        "import pandas as pd\n"
        "data = {'Name': ['Tom', 'Nick', 'John'], 'Age': [20, 21, 19]}\n"
        "df = pd.DataFrame(data).convert_dtypes()\n"
        "# eda"
    ),
    (
        "def factorial(n):\n"
        "    if n == 0:\n"
        "        return 1\n"
        "    else:"
    ),
    (
        "def fibonacci(n):\n"
        "    if n <= 0:\n"
        '        raise ValueError("Incorrect input")\n'
        "    elif n == 1:\n"
        "        return 0\n"
        "    elif n == 2:\n"
        "        return 1\n"
        "    else:"
    ),
    (
        "import matplotlib.pyplot as plt\n"
        "import numpy as np\n"
        "x = np.linspace(0, 10, 100)\n"
        "# simple plot"
    ),
    "def reverse_string(s:str) -> str:\n    return",
    "def is_palindrome(word:str) -> bool:\n    return",
    (
        "def bubble_sort(lst: list):\n"
        "    n = len(lst)\n"
        "    for i in range(n):\n"
        "        for j in range(0, n-i-1):"
    ),
    (
        "def binary_search(arr, low, high, x):\n"
        "    if high >= low:\n"
        "        mid = (high + low) // 2\n"
        "        if arr[mid] == x:\n"
        "            return mid\n"
        "        elif arr[mid] > x:"
    ),
]

# One row per prompt: [prompt, temperature, max_new_tokens, top_p,
# repetition_penalty]. Each row gets its own list so rows are independent.
examples = [[prompt, *_EXAMPLE_SETTINGS] for prompt in _EXAMPLE_PROMPTS]
|
|
| |
# Page layout and event wiring for the demo.
# NOTE(review): the container nesting below was reconstructed from a flattened
# source listing -- confirm it against the rendered page.
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # The prompt box is pre-filled with a randomly chosen example.
                instruction = gr.Textbox(
                    value=random.choice([e[0] for e in examples]),
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    # Lowest selectable value is one step (64)
                                    # rather than 0: generation is invoked with
                                    # min_new_tokens=8, so a zero-token budget
                                    # is never a satisfiable request.
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=128,
                                        minimum=64,
                                        maximum=512,
                                        step=64,
                                        interactive=True,
                                        info="Number of tokens to generate",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.1,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Display-only selector: it is not passed to the
                        # generation callback.
                        version = gr.Dropdown(
                            [
                                "smol_llama-101M-GQA-python",
                            ],
                            value="smol_llama-101M-GQA-python",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                # The input list mirrors the five values in each example row
                # and the five parameters of gradio_interface.
                gr.Examples(
                    examples=examples,
                    inputs=[
                        instruction,
                        temperature,
                        max_new_tokens,
                        top_p,
                        repetition_penalty,
                    ],
                    cache_examples=False,
                    fn=gradio_interface,
                    outputs=[output],
                )
                gr.Markdown(base_model_info)
                gr.Markdown(formats)

    # Event listeners have to be registered while the Blocks context is open.
    # No batching options are passed: gradio_interface handles one request per
    # call and the listener is not registered with batch=True.
    submit.click(
        gradio_interface,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        show_progress=True,
    )
|
|
# Turn on request queuing (queue size capped at 10), then start the server on
# the configured port with the worker count reported by utils.get_workers().
demo.queue(max_size=10)
demo.launch(
    debug=True,
    server_port=DEFAULT_PORT,
    max_threads=utils.get_workers(),
)
|
|