import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch


MODEL_NAME = "Khalid02/fine_tuned_law_llama3_8b_4bit"
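
# NOTE (assumption): 4-bit loading relies on the bitsandbytes package and a
# CUDA-capable GPU, and device_map="auto" requires accelerate; both are assumed
# to be listed in this Space's requirements.txt.
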
# st.cache_resource is the current Streamlit API for caching heavyweight
# objects such as models; it replaces the deprecated
# @st.cache(allow_output_mutation=True).
@st.cache_resource
def load_model_and_tokenizer():
    st.info("Loading model and tokenizer. This may take a while...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Load the model in 4-bit precision. Recent transformers releases
        # expect quantization settings via BitsAndBytesConfig rather than the
        # deprecated bare load_in_4bit kwarg.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        st.success("Model and tokenizer loaded successfully.")
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None
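

# st.cache_resource caches the returned objects once per process, so the model
# below is loaded a single time and shared across reruns and user sessions.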
tokenizer, model = load_model_and_tokenizer()

st.title("Law LLaMA3 Fine-Tuned Model Space")
st.write(
    """
    This space demonstrates the **Khalid02/fine_tuned_law_llama3_8b_4bit** model.
    Enter a prompt below and click **Generate** to see the model in action.
    """
)

if tokenizer is None or model is None:
    st.error("The model could not be loaded. Check logs for details.")
else:
    prompt = st.text_area("Enter your prompt here:")

    if st.button("Generate"):
        if not prompt.strip():
            st.error("Please enter a valid prompt.")
        else:
            with st.spinner("Generating response..."):
                try:
                    # Tokenize and move tensors to the model's device; keeping
                    # the full encoding passes the attention_mask along with
                    # the input_ids.
                    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

                    # Sample a continuation; temperature and top_p control how
                    # random the output is.
                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=256,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id,
                    )

                    # Note: this decodes the full sequence, so the prompt is
                    # echoed at the start of the output.
                    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

                    st.text_area("Generated Text", value=generated_text, height=300)
                except Exception as e:
                    st.error(f"Error during text generation: {e}")
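

# A minimal optional tweak (not part of the original app): to display only the
# newly generated continuation instead of echoing the prompt, decode just the
# tokens that follow the input, e.g.
#
#     new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
#     generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)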