# data-ject's picture
# Update app.py
# b5cb89d verified
import gradio as gr
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
import os
# Path to the locally downloaded model weights/tokenizer files.
local_model_dir = "./chameleon-7b"
# Fail fast with an explicit error if the model directory is missing.
# (An `assert` is unsuitable here: it is stripped when Python runs with -O.)
if not os.path.exists(local_model_dir):
    raise FileNotFoundError(f"{local_model_dir} does not exist")
# Load the tokenizer and model from the local directory (no network access).
tokenizer = LlamaTokenizer.from_pretrained(local_model_dir)
model = LlamaForCausalLM.from_pretrained(local_model_dir)
# Function to generate response
def generate_response(input_text):
    """Generate a reply from the loaded model for the given prompt.

    Args:
        input_text: The user's prompt string.

    Returns:
        The decoded generated text (the full sequence is decoded, so the
        prompt is included in the returned string).
    """
    # Nothing to generate for an empty/whitespace-only prompt.
    if not input_text or not input_text.strip():
        return ""
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the attention mask explicitly: omitting it triggers a
            # transformers warning and can produce wrong results when the
            # pad token equals the eos token.
            attention_mask=inputs.attention_mask,
            max_length=500,
            num_return_sequences=1,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Custom CSS: Nord-inspired dark palette for the page body, the main
# Gradio container, and buttons. Injected into the app via gr.Blocks(css=...).
css = """
body {
background-color: #2e3440;
color: #d8dee9;
font-family: 'Roboto', sans-serif;
}
.gradio-container {
background-color: #3b4252;
border: 2px solid #88c0d0;
border-radius: 10px;
padding: 20px;
}
.gr-button {
background-color: #5e81ac;
color: #d8dee9;
border-radius: 5px;
}
"""
# Build the Gradio UI: a prompt box and response box on the left,
# usage instructions on the right, wired through generate_response.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Chat with Chameleon-7B", elem_id="header")
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(
                label="Your Input",
                placeholder="Type your message here...",
                lines=4,
            )
            reply_box = gr.Textbox(
                label="Chameleon-7B Response",
                placeholder="The model's response will appear here...",
                lines=4,
            )
            send_button = gr.Button("Submit")
        with gr.Column():
            gr.Markdown("### Instructions")
            gr.Markdown("Enter your message in the input box and press submit to get a response from Chameleon-7B.")
    # Clicking Submit runs the model and writes its reply into the output box.
    send_button.click(fn=generate_response, inputs=prompt_box, outputs=reply_box)
# Start the Gradio web server.
demo.launch()