# Source: Hugging Face Space file "app.py" by Alok2304 (commit 8f75d90, verified)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Hugging Face Hub repo ID of the fine-tuned legal-domain LLaMA-3 8B model (4-bit)
MODEL_NAME = "Khalid02/fine_tuned_law_llama3_8b_4bit"
# Cache the model/tokenizer as a shared resource so they are created once per
# server process instead of on every Streamlit rerun.
# NOTE: st.cache(allow_output_mutation=True) was deprecated and removed in
# Streamlit 1.18+; st.cache_resource is its replacement for unserializable
# objects such as ML models.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and 4-bit quantized model from the Hugging Face Hub.

    Returns:
        tuple: (tokenizer, model) on success, or (None, None) if loading fails
        (the error is also surfaced in the UI via st.error).
    """
    st.info("Loading model and tokenizer. This may take a while...")
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Load model in 4-bit precision
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            load_in_4bit=True,  # 4-bit quantization (requires bitsandbytes)
            device_map="auto",  # automatically place on GPU/CPU
        )
        st.success("Model and tokenizer loaded successfully.")
        return tokenizer, model
    except Exception as e:
        # Report the failure in the UI; the caller checks for (None, None).
        st.error(f"Error loading model: {e}")
        return None, None
# Load the model and tokenizer (cached across reruns by the decorator above)
tokenizer, model = load_model_and_tokenizer()

# App title and description
st.title("Law LLaMA3 Fine-Tuned Model Space")
st.write(
    """
This space demonstrates the **Khalid02/fine_tuned_law_llama3_8b_4bit** model.
Enter a prompt below and click **Generate** to see the model in action.
"""
)

# Check if the model and tokenizer are loaded
if tokenizer is None or model is None:
    st.error("The model could not be loaded. Check logs for details.")
else:
    # Text area for prompt input
    prompt = st.text_area("Enter your prompt here:")
    # Generate button
    if st.button("Generate"):
        if not prompt.strip():
            st.error("Please enter a valid prompt.")
        else:
            with st.spinner("Generating response..."):
                try:
                    # Tokenize and move all tensors to the model's device.
                    # Pass the attention mask explicitly to generate():
                    # omitting it makes transformers infer the mask from
                    # pad_token_id, which emits a warning and can produce
                    # unreliable results.
                    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                    # Generate text (sampling decode)
                    output_ids = model.generate(
                        inputs.input_ids,
                        attention_mask=inputs.attention_mask,
                        max_new_tokens=256,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                    # Decode the generated tokens (includes the prompt)
                    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    st.text_area("Generated Text", value=generated_text, height=300)
                except Exception as e:
                    st.error(f"Error during text generation: {e}")