# app.py — CUTE-Llama Gradio demo Space (upload commit 0523c2b, author Jalalidin)
import os

import gradio as gr
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Tokenizer comes from the original Llama-2 base repo; weights from the finetune.
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
FINETUNE_MODEL = "CMLI-NLP/CUTE-Llama"

# The base repo is gated: a Hugging Face token must be supplied via the
# Space's secrets (HUGGINGFACEHUB_API_TOKEN). May be None for public mirrors.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# `use_auth_token` is deprecated in recent transformers; `token` is the
# supported keyword for authenticated downloads.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL, token=hf_token)
model = LlamaForCausalLM.from_pretrained(
    FINETUNE_MODEL,
    device_map="auto",        # shard/place automatically on available devices
    torch_dtype="auto",       # take dtype from the checkpoint config
    load_in_8bit=True,        # requires bitsandbytes; shrinks the 7B footprint
    token=hf_token,
)
def generate_response(prompt, max_new_tokens, temperature, top_p):
    """Generate a completion for *prompt* with the loaded CUTE-Llama model.

    Args:
        prompt: User text. Blank or whitespace-only input short-circuits to "".
        max_new_tokens: Cap on newly generated tokens (coerced to int).
        temperature: Sampling temperature; 0 selects greedy decoding.
        top_p: Nucleus-sampling cutoff (only applied when sampling).

    Returns:
        Decoded text for the full output sequence (prompt included, since
        output_ids[0] is decoded as-is), with special tokens stripped.
    """
    if not prompt.strip():
        return ""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Fix: the original built a TextStreamer here but never passed it to
    # generate() — and TextStreamer was never imported, so every non-empty
    # call raised NameError. The dead object is removed.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # Only forward sampling knobs when sampling: temperature=0.0 with
    # do_sample=False makes recent transformers warn/error.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Gradio UI: prompt + generation controls on the left, model output on the
# right. Statement order inside the Blocks context defines the layout.
# NOTE(review): `gr` (gradio) is never imported at the top of this file as
# written — this block raises NameError; confirm the import upstream.
with gr.Blocks(title="CUTE-Llama") as demo:
    # Page header.
    gr.Markdown("# CUTE-Llama\nMultilingual Llama-2-7B finetune for Chinese, Uyghur, and Tibetan.")
    with gr.Row():
        with gr.Column(scale=2):
            # Input side: free-text prompt plus the three generation knobs
            # forwarded verbatim to generate_response.
            prompt = gr.Textbox(label="Prompt", lines=8, placeholder="Ask in Chinese, Uyghur, Tibetan, or English...")
            with gr.Row():
                max_new_tokens = gr.Slider(32, 512, value=256, step=8, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            run = gr.Button("Generate")
        with gr.Column(scale=3):
            output = gr.Textbox(label="Output", lines=12)
    # Wire the button: component values map positionally onto
    # generate_response(prompt, max_new_tokens, temperature, top_p).
    run.click(
        fn=generate_response,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=output
    )

# Launch the Gradio server when run as a script (Spaces execute app.py directly).
if __name__ == "__main__":
    demo.launch()