"""Deploy the trained model to a HuggingFace Space for interactive testing."""
import argparse
import os
import sys
import tempfile

from huggingface_hub import HfApi, create_repo
|
|
|
|
# Source template for the ``app.py`` that is uploaded to the Space.
#
# The template is rendered with ``str.format``; the only substitution fields
# are ``{base_model}`` and ``{adapter_repo}``.  Every other brace that must
# survive into the generated program (the f-string fields that read the
# generated module's own ``BASE_MODEL`` constant) is doubled, otherwise
# ``str.format`` raises ``KeyError: 'BASE_MODEL'`` while rendering.
SPACE_APP = '''
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr

BASE_MODEL = "{base_model}"
ADAPTER_REPO = "{adapter_repo}"

print("Loading...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(base, ADAPTER_REPO)
model.eval()
print("Loaded")


def generate(prompt, max_tokens, temp, top_k):
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=int(max_tokens),
            do_sample=True, temperature=float(temp), top_k=int(top_k),
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)


with gr.Blocks(title=f"Mel-{{BASE_MODEL}}") as demo:
    gr.Markdown(f"# Mel corpus fine-tune of {{BASE_MODEL}}")
    gr.Markdown("Base model: uncontaminated base, no RLHF. Trained on full Mel unified corpus.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="The shared body channel", lines=4)
            max_tokens = gr.Slider(20, 500, value=150, step=10)
            temp = gr.Slider(0.1, 2.0, value=0.8, step=0.1)
            top_k = gr.Slider(0, 100, value=40, step=5)
            btn = gr.Button("Generate")
        with gr.Column():
            output = gr.Textbox(label="Output", lines=20)
    btn.click(generate, [prompt, max_tokens, temp, top_k], output)

demo.launch()
'''
|
|
# Contents of the Space's requirements.txt: one unpinned package per line,
# terminated by a trailing newline.
REQS = """torch
transformers
peft
gradio
accelerate
"""
|
|
# Contents of the Space's README.md: a YAML front-matter block (delimited by
# the two ``---`` lines) followed by a one-line description.
# NOTE(review): the Spaces configuration reference appears to document
# ``suggested_hardware`` rather than ``hardware`` -- confirm that the
# ``hardware`` key below has any effect before relying on it.
README_MD = """---
title: Mel Trained Model
emoji: 🌑
colorFrom: gray
colorTo: purple
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
hardware: cpu-basic
---

Trained on Mel unified corpus. See model card for details.
"""
|
|
|
|
def _write_space_files(target_dir, base_model, adapter_repo):
    """Render app.py, requirements.txt and README.md into *target_dir*."""
    rendered = {
        'app.py': SPACE_APP.format(base_model=base_model, adapter_repo=adapter_repo),
        'requirements.txt': REQS,
        'README.md': README_MD,
    }
    for filename, content in rendered.items():
        # README_MD contains a non-ASCII emoji, so the platform default
        # encoding cannot be relied on; always write UTF-8.
        with open(os.path.join(target_dir, filename), 'w', encoding='utf-8') as f:
            f.write(content)


def main(argv=None):
    """Create (or reuse) a Gradio Space and upload the generated app to it.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Command-line options:
        --base-model    Model id substituted into the generated app.py.
        --adapter-repo  Adapter repo id substituted into the generated app.py.
        --space-name    Repo id of the Space to create and upload to.
        --token         HuggingFace token. Optional; defaults to the
                        ``HF_TOKEN`` environment variable. When neither is
                        set, ``None`` is handed to huggingface_hub.

    Raises:
        Whatever ``HfApi.upload_folder`` raises if the upload fails. A failure
        to create the Space is reported on stderr but is not fatal.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-model', required=True)
    parser.add_argument('--adapter-repo', required=True)
    parser.add_argument('--space-name', required=True)
    parser.add_argument(
        '--token',
        default=os.environ.get('HF_TOKEN'),
        help='HuggingFace access token (default: $HF_TOKEN)',
    )
    args = parser.parse_args(argv)

    api = HfApi(token=args.token)

    # Stage into a fresh temporary directory: upload_folder sends everything
    # under folder_path, so a reused fixed directory could leak stale files
    # from an earlier run into the Space.
    with tempfile.TemporaryDirectory(prefix='space-') as staging_dir:
        # Render locally first so a template/IO error aborts before anything
        # is created remotely.
        _write_space_files(staging_dir, args.base_model, args.adapter_repo)

        try:
            create_repo(
                args.space_name,
                repo_type='space',
                space_sdk='gradio',
                token=args.token,
                exist_ok=True,
            )
        except Exception as exc:
            # Creation stays best-effort (the upload below may still work
            # against an existing Space), but the failure is no longer silent.
            print(
                f"Warning: could not create Space {args.space_name!r}: {exc}",
                file=sys.stderr,
            )

        api.upload_folder(
            folder_path=staging_dir,
            repo_id=args.space_name,
            repo_type='space',
        )

    print(f"Deployed: https://huggingface.co/spaces/{args.space_name}")
|
|
|
|
# Script entry point: run the deployment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|