# Hugging Face Space app.py by vishnu714 — "Update app.py" (commit da747c4, verified)
import gradio as gr
from ctransformers import AutoModelForCausalLM
import spaces
# Initialize the Llama model outside of the GPU context to avoid blocking issues.
def load_model():
    """Load the quantized Llama-2-7B GGUF model from the Hugging Face Hub.

    Returns:
        The ctransformers causal LM, configured for a 2048-token context
        window and at most 512 newly generated tokens per call.

    Raises:
        Exception: re-raised after logging when the download or load fails,
            so the Space crashes loudly instead of serving a broken model.
    """
    try:
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Llama-2-7B-GGUF",
            model_file="llama-2-7b.Q5_K_M.gguf",
            model_type="llama",
            max_new_tokens=512,
            context_length=2048,
        )
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        # Bare `raise` re-raises the active exception with its original
        # traceback; `raise e` would reset the raise point to this line.
        raise
# Load the model once at import time (outside of any GPU context); the
# resulting `llm` object is shared by every request handled by this Space.
llm = load_model()
# Create a function to generate the response
def generate_response(task_type, input_text):
    """Build a task-specific prompt, run it through the shared model, and
    return the stripped completion.

    Args:
        task_type: one of "question", "story", "code", "creative"; any other
            value falls back to the "question" template.
        input_text: the user's raw prompt text.

    Returns:
        The model's output with surrounding whitespace removed, or an
        "An error occurred: ..." string if generation fails.
    """
    templates = {
        "question": f"Human: Answer the following question: {input_text}\n\nAssistant: ",
        "story": f"Human: Write a short story about: {input_text}\n\nAssistant: ",
        "code": f"Human: Write Python code for: {input_text}\n\nAssistant: ",
        "creative": f"Human: Create something creative based on: {input_text}\n\nAssistant: ",
    }
    # Unknown task types degrade gracefully to plain question answering.
    chosen_prompt = templates.get(task_type, templates["question"])
    try:
        # Hold the GPU only for the duration of the generation call.
        with spaces.GPU():
            completion = llm(chosen_prompt)
        return completion.strip()
    except Exception as exc:
        print(f"An error occurred during generation: {exc}")
        return f"An error occurred: {str(exc)}"
# --- Gradio interface ------------------------------------------------------
with gr.Blocks(title="Advanced Llama-2-7b Interface") as iface:
    gr.Markdown("# Advanced Llama-2-7b Chat Interface")
    gr.Markdown("This interface allows you to interact with the Llama-2-7b model for various tasks.")

    with gr.Row():
        with gr.Column(scale=1):
            # Selector that decides which prompt template generate_response uses.
            task_selector = gr.Radio(
                ["question", "story", "code", "creative"],
                label="Task Type",
                info="Select the type of task you want the model to perform.",
            )
        with gr.Column(scale=3):
            prompt_box = gr.Textbox(
                lines=4, label="Input", placeholder="Enter your question or prompt here..."
            )

    generate_btn = gr.Button("Generate")
    result_box = gr.Textbox(lines=10, label="Generated Output")

    # Wire the button: (task type, input text) -> generated string.
    generate_btn.click(generate_response, inputs=[task_selector, prompt_box], outputs=result_box)

# Listen on all interfaces on the conventional Hugging Face Spaces port.
iface.launch(server_name="0.0.0.0", server_port=7860)