Hugging Face Spaces app (build status at capture time: Runtime error).
import os

import gradio as gr
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, TextStreamer
# Model identifiers: base Llama-2 tokenizer and the CUTE-Llama finetune.
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
FINETUNE_MODEL = "CMLI-NLP/CUTE-Llama"

# Llama-2 is a gated repo, so an access token is required; read it from the
# standard Spaces secret. May be None for public mirrors.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# `use_auth_token=` is deprecated/removed in recent transformers releases;
# `token=` is the supported keyword.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL, token=hf_token)
model = LlamaForCausalLM.from_pretrained(
    FINETUNE_MODEL,
    device_map="auto",   # place layers on available GPU(s)/CPU automatically
    torch_dtype="auto",
    # NOTE(review): requires bitsandbytes at runtime; newer transformers
    # prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True).
    load_in_8bit=True,
    token=hf_token,
)
def generate_response(prompt, max_new_tokens, temperature, top_p):
    """Generate a completion for *prompt* and return only the new text.

    Args:
        prompt: User prompt string; blank/whitespace input short-circuits to "".
        max_new_tokens: Cap on generated tokens (slider value, cast to int).
        temperature: Sampling temperature; 0 means greedy decoding.
        top_p: Nucleus-sampling probability mass (used only when sampling).

    Returns:
        The decoded completion with the prompt tokens stripped off.
    """
    if not prompt.strip():
        return ""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # transformers warns when sampling knobs are passed with do_sample=False,
    # so only forward them when sampling is actually enabled.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)
    # Bug fix: the original decoded the whole sequence, echoing the prompt
    # back to the user. Decode only the newly generated tokens.
    # (Also removed an unused TextStreamer that was never wired into generate.)
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
# --- Gradio UI -------------------------------------------------------------
# Two-column layout: prompt + generation controls on the left, output on the
# right. Indentation below reconstructs the nesting lost in the scraped copy:
# the slider Row and the Generate button sit inside the left (scale=2) Column.
with gr.Blocks(title="CUTE-Llama") as demo:
    gr.Markdown("# CUTE-Llama\nMultilingual Llama-2-7B finetune for Chinese, Uyghur, and Tibetan.")
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt", lines=8, placeholder="Ask in Chinese, Uyghur, Tibetan, or English...")
            with gr.Row():
                max_new_tokens = gr.Slider(32, 512, value=256, step=8, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            run = gr.Button("Generate")
        with gr.Column(scale=3):
            output = gr.Textbox(label="Output", lines=12)
    # Wire the button to the generation function; slider values are passed
    # positionally in the same order generate_response declares them.
    run.click(
        fn=generate_response,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()