""" A simple Gradio web app to interact with the AlpineLLM model """ import gradio as gr import os import shutil import torch from huggingface_hub import hf_hub_download from demo_inference import AlpineLLMInference from config_util import Config HF_TOKEN = os.environ.get("HF_TOKEN", None) def download_model(cfg): """ Download the model weights from Hugging Face Hub """ model_path = hf_hub_download( repo_id=cfg.repo_id, filename=cfg.model_name, token=HF_TOKEN, cache_dir=cfg.cache_dir ) return model_path def start_app(): """ Start the web app via Gradio with custom layout """ with gr.Blocks(css="""#builtwithgradio, .footer, .svelte-1ipelgc {display: none !important;}""") as app: gr.Markdown("

AlpineLLM App

") gr.Markdown( "

" "A domain-specific language model for alpine storytelling.
" "Generate climbing stories, mountain impressions, and expedition-style text." "

" ) with gr.Row(): with gr.Column(scale=1): prompt = gr.Textbox( lines=8, label="Your alpine prompt...", placeholder="A dawn climb on the Matterhorn..." ) max_tokens = gr.Slider(50, 1000, value=300, step=10, label="Max output tokens") generate_btn = gr.Button("🚀 Generate") with gr.Column(scale=2): output = gr.Textbox(lines=20, label="Generated Alpine Story", interactive=False) # Bind button click to inference generate_btn.click( fn=inference.generate_text, inputs=[prompt, max_tokens], outputs=output ) app.launch(server_name="0.0.0.0", server_port=7860) if __name__ == '__main__': os.chdir(os.path.dirname(os.path.abspath(__file__))) # Define the configuration cfg = { 'cuda_id': 0, 'model_type': 'transformer', 'repo_id': "Borzyszkowski/AlpineLLM-model", 'model_name': "best_model", 'cache_dir': "./model-cache", } cfg = Config(cfg) # Define the hyperparameters hyperparam_cfg={ "embedding_dim": 384, "num_heads": 6, "num_layers": 6, "dropout": 0.2, "context_len": 256, "lr": 3e-4, } hyperparam_cfg = Config(hyperparam_cfg) # Ensure model weights are available shutil.rmtree(cfg.cache_dir, ignore_errors=True) cfg.load_weights_path = download_model(cfg) # Start the application inference = AlpineLLMInference(cfg, hyperparam_cfg) start_app()