Rainbowdesign's picture
Update app.py
9375dcf verified
raw
history blame contribute delete
959 Bytes
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
# Checkpoint shared by the tokenizer and the model so both always match.
MODEL_ID = "codellama/CodeLlama-7b-Instruct-hf"

# Tokenizer: provides the chat template and the token <-> text conversion.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Causal LM: device_map="auto" leaves weight placement to the library
# instead of pinning the model to one explicitly chosen device.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
def chat_fn(user_message):
    """Generate a single-turn reply to ``user_message``.

    The message is wrapped as one ``user`` turn, rendered through the
    tokenizer's chat template, and passed to ``model.generate`` with a cap
    of 100 new tokens. No conversation history is kept between calls.

    Args:
        user_message: Text typed by the user.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        removed), with special tokens stripped.
    """
    messages = [{"role": "user", "content": user_message}]

    # Render the chat template straight to tensors and move them to the
    # device the model lives on.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=100)

    # The returned sequence starts with the prompt tokens; slice them off so
    # only the continuation is decoded.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_length:]

    # skip_special_tokens=True keeps control tokens (e.g. the end-of-sequence
    # marker) out of the text shown to the user.
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Gradio UI: one text box in, one text box out, wired to chat_fn.
_prompt_box = gr.Textbox(label="User Input")
_reply_box = gr.Textbox(label="Model Output")

demo = gr.Interface(
    title="CodeLlama-7b-Instruct Chat",
    fn=chat_fn,
    inputs=_prompt_box,
    outputs=_reply_box,
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()