import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model (will use cached version if available)
model_id = "meta-llama/Llama-2-7b-chat-hf"
# Check for GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
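# Note: this checkpoint is gated on the Hugging Face Hub, so access must be
# granted to the account and a token made available to the Space (e.g. via an
# HF_TOKEN secret). A lower-memory alternative when a GPU is present
# (a sketch, not part of the original app):
#   model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device)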
def generate_text(prompt, max_new_tokens=200):
    # Tokenize the prompt and move it to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Generate a response, sampling with temperature for some variety
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        do_sample=True,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response
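
# Optional helper (not in the original app): the chat-tuned Llama-2 checkpoints
# are trained on the [INST] prompt format, so wrapping the raw prompt this way
# usually gives better chat behaviour. A minimal single-turn sketch, assuming
# the standard Llama-2 chat template (the tokenizer adds the BOS token itself):
def format_chat_prompt(user_message, system_prompt="You are a helpful assistant."):
    # System prompt inside <<SYS>> tags, user message in one [INST] ... [/INST] block
    return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"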
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LLaMA 2 7B Chat Demo")
    with gr.Row():
        input_text = gr.Textbox(label="Input Prompt", lines=3)
        output_text = gr.Textbox(label="Generated Response", lines=3)
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        fn=generate_text,
        inputs=input_text,
        outputs=output_text,
    )
demo.launch(server_name="0.0.0.0", server_port=7860)