File size: 5,481 Bytes
bb2fa48
 
91284b8
244696f
bb2fa48
0f2b26a
bb2fa48
 
 
 
 
91284b8
 
bb2fa48
4275c77
 
bb2fa48
 
 
 
 
 
4275c77
0f2b26a
bb2fa48
 
 
 
91284b8
 
 
 
 
 
 
 
 
 
 
 
 
 
bb2fa48
 
91284b8
 
 
 
e8e2b5c
 
 
 
 
 
 
 
 
 
91284b8
d577564
 
7b7f57d
d577564
 
 
 
 
 
 
7c1dc68
d577564
 
7c1dc68
91284b8
 
 
bc7b6c9
bb2fa48
 
 
fefddfe
bb2fa48
 
 
 
 
 
91284b8
bb2fa48
 
91284b8
7c1dc68
bb2fa48
 
 
 
 
 
 
 
15984d7
bb2fa48
 
 
 
 
 
 
 
 
c2964c6
1184db0
0f2b26a
bb2fa48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f2b26a
bb2fa48
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
""" A simple Gradio web app to interact with the AlpineLLM model """

import base64
import gradio as gr
import os
import shutil
import torch

from huggingface_hub import hf_hub_download

from config_util import Config
from demo_inference import AlpineLLMInference
from style import custom_css

HF_TOKEN = os.environ.get("HF_TOKEN", None)


def download_model(cfg):
    """ Fetch the model checkpoint from the Hugging Face Hub and return its local path.

    Uses the module-level HF_TOKEN (may be None for public repos) and caches
    the file under cfg.cache_dir.
    """
    return hf_hub_download(
        repo_id=cfg.repo_id,
        filename=cfg.model_name,
        token=HF_TOKEN,
        cache_dir=cfg.cache_dir,
    )


def image_to_base64_data_url(filepath: str) -> str:
    """ Convert an image file to a Base64 data URL for embedding in HTML.

    Returns an empty string on any failure (missing file, unreadable data)
    so callers can safely interpolate the result into markup.
    """
    mime_by_ext = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
    }
    try:
        extension = os.path.splitext(filepath)[1].lower()
        # Unknown extensions fall back to JPEG, matching browser leniency.
        mime = mime_by_ext.get(extension, "image/jpeg")
        with open(filepath, "rb") as image_file:
            payload = base64.b64encode(image_file.read()).decode("utf-8")
    except Exception as e:
        # Best-effort: log and degrade to an empty src rather than crash the UI.
        print(f"Error encoding image to Base64: {e}")
        return ""
    return f"data:{mime};base64,{payload}"


def start_app(engine=None):
    """ Start the web app via Gradio with custom layout.

    Args:
        engine: Object exposing ``generate_text(prompt, max_tokens)`` used to
            serve generation requests. Defaults to the module-level
            ``inference`` created under ``__main__`` — kept for backward
            compatibility with the original global-based wiring, but callers
            can now inject their own engine explicitly.
    """
    if engine is None:
        # Fall back to the global set up in __main__ (original behavior).
        engine = inference
    GOOGLE_FONTS_URL = "<link href='https://fonts.googleapis.com/css2?family=Noto+Sans+SC:wght@400;700&display=swap' rel='stylesheet'>"
    LOGO_IMAGE_PATH = "assets/background_round.png"
    # Embed the logo as a data URL so it renders without a static-file route.
    logo_data_url = image_to_base64_data_url(LOGO_IMAGE_PATH) if os.path.exists(LOGO_IMAGE_PATH) else ""
    with gr.Blocks(head=GOOGLE_FONTS_URL, css=custom_css, theme=gr.themes.Soft()) as app:
        # Header: title, tagline, author link.
        gr.HTML("""
        <div class="app-header">
            <h1>AlpineLLM Live Demo</h1>
            <p>
                A domain-specific language model for alpine storytelling. <br>
                Try asking about mountain adventures! 🏔️ <br>
                <strong>Author:</strong> <a href="https://borzyszkowski.github.io/">Bartek Borzyszkowski</a>
            </p>
        </div>
        """)

        # Logo, quick links, and a CPU-only performance notice.
        gr.HTML(f"""
            <div class="app-header">
                <img src="{logo_data_url}" alt="AlpineLLM" style="max-height:10%; width: auto; margin: 10px auto; display: block;">
            </div>
            <div class="quick-links">
                <a href="https://github.com/Borzyszkowski/AlpineLLM" target="_blank">GitHub</a> | <a href="https://huggingface.co/Borzyszkowski/AlpineLLM-Tiny-10M-Base" target="_blank">Model Page</a>
            </div>
            <div class="notice">
                <strong>Heads up:</strong> This space shows a free CPU-only demo of the model, so inference may take a few seconds. Text generation of the tiny model may lack full coherence due to its limited size and character-level tokenization. Consider using the source repository to load larger pretrained weights and run inference on a GPU. 
            </div>
            <br>
            """)

        gr.Markdown("<h3> About AlpineLLM</h3>")
        gr.Markdown(
            "<p>"
            "AlpineLLM-Tiny-10M-Base is a lightweight base language model with ~10.8 million trainable parameters. It was pre-trained from scratch on raw text corpora drawn primarily from public-domain literature on alpinism, including expedition narratives and climbing essays. <br><br>"
            "This demo showcases the model's text generation capabilities within its specialized domain. Please note that AlpineLLM is a base model, and it has not been fine-tuned for downstream tasks such as summarization or dialogue. Its outputs reflect patterns learned directly from the training texts. <br><br>"
            "</p>"
        )
        # Two-column layout: prompt controls on the left, generated text on the right.
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(
                    lines=8,
                    label="Your alpine prompt...",
                    placeholder="A dawn climb on the Matterhorn..."
                )
                max_tokens = gr.Slider(50, 1000, value=300, step=10, label="Max output tokens")
                generate_btn = gr.Button("⛏️ Generate")

            with gr.Column(scale=2):
                output = gr.Textbox(lines=15, label="Generated Alpine Story", interactive=False)
        gr.Markdown("<br>")

        # Bind button click to inference
        generate_btn.click(
            fn=engine.generate_text,
            inputs=[prompt, max_tokens],
            outputs=output
        )

    # 0.0.0.0 so the app is reachable inside the Hugging Face Space container.
    app.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == '__main__':
    # Run from the script's own directory so relative paths (assets/, cache) resolve.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # Define the configuration
    cfg = {
        'cuda_id': 0,
        'model_type': 'transformer',
        'repo_id': "Borzyszkowski/AlpineLLM-Tiny-10M-Base",
        'model_name': "best_model.pt",
        'cache_dir': "./model-cache",
    }
    cfg = Config(cfg)

    # Define the hyperparameters
    # NOTE(review): these must match the published checkpoint's architecture —
    # presumably fixed by the repo's training config; verify before changing.
    hyperparam_cfg={
        "embedding_dim": 384,
        "num_heads": 6,
        "num_layers": 6,
        "dropout": 0.2,
        "context_len": 256,
        "lr": 3e-4,
    }
    hyperparam_cfg = Config(hyperparam_cfg)

    # Ensure model weights are available
    # Wipe any stale cache first so the checkpoint is always freshly downloaded.
    shutil.rmtree(cfg.cache_dir, ignore_errors=True)
    cfg.load_weights_path = download_model(cfg)

    # Start the application
    # `inference` is a module global read by start_app(); it must exist before the call.
    inference = AlpineLLMInference(cfg, hyperparam_cfg)
    start_app()