AlpineLLM-App / app.py
Borzyszkowski
adjusted width
7b7f57d
""" A simple Gradio web app to interact with the AlpineLLM model """
import base64
import gradio as gr
import os
import shutil
import torch
from huggingface_hub import hf_hub_download
from config_util import Config
from demo_inference import AlpineLLMInference
from style import custom_css
# Optional Hugging Face access token taken from the environment; None when unset
HF_TOKEN = os.getenv("HF_TOKEN")
def download_model(cfg):
    """ Fetch the model checkpoint file from the Hugging Face Hub.

    Args:
        cfg: Configuration object exposing ``repo_id``, ``model_name`` (the file
            to fetch inside the repo) and ``cache_dir`` (local download cache).

    Returns:
        The local path returned by ``hf_hub_download`` for the requested file.
    """
    # The module-level HF_TOKEN (possibly None) is forwarded for authenticated access
    download_kwargs = dict(
        repo_id=cfg.repo_id,
        filename=cfg.model_name,
        token=HF_TOKEN,
        cache_dir=cfg.cache_dir,
    )
    return hf_hub_download(**download_kwargs)
def image_to_base64_data_url(filepath: str) -> str:
    """ Convert an image file to a Base64 data URL for embedding in HTML.

    The MIME type is chosen from the (case-insensitive) file extension; unknown
    extensions fall back to ``image/jpeg``.

    Args:
        filepath: Path to the image file on disk.

    Returns:
        A ``data:<mime>;base64,<payload>`` string, or ``""`` when the file cannot
        be read (the error is printed rather than raised).
    """
    ext = os.path.splitext(filepath)[1].lower()
    mime_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        # Browsers refuse to render SVG/ICO payloads mislabelled as JPEG
        ".svg": "image/svg+xml",
        ".ico": "image/x-icon",
    }
    mime_type = mime_types.get(ext, "image/jpeg")
    try:
        with open(filepath, "rb") as image_file:
            payload = image_file.read()
    except OSError as e:
        # Best-effort: an unreadable image should not stop the caller (e.g. app startup)
        print(f"Error encoding image to Base64: {e}")
        return ""
    # Base64 output is pure ASCII, so decoding cannot fail
    encoded_string = base64.b64encode(payload).decode("ascii")
    return f"data:{mime_type};base64,{encoded_string}"
def start_app(engine=None, server_name="0.0.0.0", server_port=7860):
    """ Start the web app via Gradio with custom layout.

    Builds the page (header, optional logo, quick links, notice, description,
    prompt/output widgets), binds the "Generate" button to the model, and
    launches the Gradio server.

    Args:
        engine: Object whose ``generate_text(prompt, max_tokens)`` method fills the
            output box. Defaults to the module-level ``inference`` object created
            in the ``__main__`` block (the original, implicit behaviour).
        server_name: Host address passed to ``app.launch`` ("0.0.0.0" = all interfaces).
        server_port: Port passed to ``app.launch``.
    """
    if engine is None:
        # Backward-compatible default: use the global created under __main__
        engine = inference

    GOOGLE_FONTS_URL = "<link href='https://fonts.googleapis.com/css2?family=Noto+Sans+SC:wght@400;700&display=swap' rel='stylesheet'>"
    LOGO_IMAGE_PATH = "assets/background_round.png"
    logo_data_url = image_to_base64_data_url(LOGO_IMAGE_PATH) if os.path.exists(LOGO_IMAGE_PATH) else ""

    # Only emit the <img> when a data URL exists: an empty src renders a broken-image icon
    logo_html = (
        f"""<div class="app-header">
<img src="{logo_data_url}" alt="AlpineLLM" style="max-height:10%; width: auto; margin: 10px auto; display: block;">
</div>"""
        if logo_data_url
        else ""
    )

    with gr.Blocks(head=GOOGLE_FONTS_URL, css=custom_css, theme=gr.themes.Soft()) as app:
        gr.HTML("""
<div class="app-header">
<h1>AlpineLLM Live Demo</h1>
<p>
A domain-specific language model for alpine storytelling. <br>
Try asking about mountain adventures! 🏔️ <br>
<strong>Author:</strong> <a href="https://borzyszkowski.github.io/">Bartek Borzyszkowski</a>
</p>
</div>
""")
        gr.HTML(f"""
{logo_html}
<div class="quick-links">
<a href="https://github.com/Borzyszkowski/AlpineLLM" target="_blank">GitHub</a> | <a href="https://huggingface.co/Borzyszkowski/AlpineLLM-Tiny-10M-Base" target="_blank">Model Page</a>
</div>
<div class="notice">
<strong>Heads up:</strong> This space shows a free CPU-only demo of the model, so inference may take a few seconds. Text generation of the tiny model may lack full coherence due to its limited size and character-level tokenization. Consider using the source repository to load larger pretrained weights and run inference on a GPU.
</div>
<br>
""")
        gr.Markdown("<h3> About AlpineLLM</h3>")
        gr.Markdown(
            "<p>"
            "AlpineLLM-Tiny-10M-Base is a lightweight base language model with ~10.8 million trainable parameters. It was pre-trained from scratch on raw text corpora drawn primarily from public-domain literature on alpinism, including expedition narratives and climbing essays. <br><br>"
            "This demo showcases the model's text generation capabilities within its specialized domain. Please note that AlpineLLM is a base model, and it has not been fine-tuned for downstream tasks such as summarization or dialogue. Its outputs reflect patterns learned directly from the training texts. <br><br>"
            "</p>"
        )

        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(
                    lines=8,
                    label="Your alpine prompt...",
                    placeholder="A dawn climb on the Matterhorn..."
                )
                max_tokens = gr.Slider(50, 1000, value=300, step=10, label="Max output tokens")
                generate_btn = gr.Button("⛏️ Generate")
            with gr.Column(scale=2):
                output = gr.Textbox(lines=15, label="Generated Alpine Story", interactive=False)
        gr.Markdown("<br>")

        # Bind button click to inference: Gradio passes (prompt text, slider value) positionally
        generate_btn.click(
            fn=engine.generate_text,
            inputs=[prompt, max_tokens],
            outputs=output
        )

    app.launch(server_name=server_name, server_port=server_port)
if __name__ == '__main__':
    # Work relative to this script so "assets/..." and "./model-cache" resolve predictably
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # Runtime / checkpoint location settings
    cfg = Config(dict(
        cuda_id=0,
        model_type='transformer',
        repo_id="Borzyszkowski/AlpineLLM-Tiny-10M-Base",
        model_name="best_model.pt",
        cache_dir="./model-cache",
    ))

    # Architecture/training hyperparameters handed to the inference wrapper
    # (presumably must match the published checkpoint — TODO confirm)
    hyperparam_cfg = Config(dict(
        embedding_dim=384,
        num_heads=6,
        num_layers=6,
        dropout=0.2,
        context_len=256,
        lr=3e-4,
    ))

    # Wipe any previous cache so the checkpoint is re-downloaded on every start
    shutil.rmtree(cfg.cache_dir, ignore_errors=True)
    cfg.load_weights_path = download_model(cfg)

    # Kept at module scope on purpose: start_app() picks up the global `inference`
    inference = AlpineLLMInference(cfg, hyperparam_cfg)
    start_app()