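"""Gradio app for the DenisCare medical assistant demo.

Loads the aboonaji/llama2finetune-v2 base model in 4-bit, applies the
dynamodenis254/dynamo-denis-llama2finetune-medical LoRA adapter, and exposes
a text-generation interface for Hugging Face Spaces.
"""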
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from peft import PeftModel, PeftConfig  # Necessary for loading the adapter weights

# --- Configuration ---
# 1. Base Llama 2 model used for fine-tuning
BASE_MODEL = "aboonaji/llama2finetune-v2"

# 2. Your newly published adapter model on the Hub
ADAPTER_MODEL = "dynamodenis254/dynamo-denis-llama2finetune-medical"

# --- Model Loading ---
# This function loads the model and runs only once when the app starts
def load_model():
    """Loads the base model and applies the fine-tuned adapter weights."""
    print(f"Loading base model: {BASE_MODEL}")

    # Check for GPU availability
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    # print(f"Using device: {device}")
    # === FIX: Define 4-bit Quantization Configuration ===
    # This dramatically reduces memory usage, solving the 'offload_dir' error.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        # bfloat16 is recommended for Llama models on Ampere or newer GPUs (e.g. A100);
        # older cards such as T4/V100 lack native bfloat16 support.
        bnb_4bit_compute_dtype=torch.bfloat16,
        llm_int8_enable_fp32_cpu_offload=True,
    )
    # Load the base model (trust_remote_code=True is needed for custom Llama repos)
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        dtype=torch.float16,  # Half precision for the modules that are not quantized
        device_map="auto",
        trust_remote_code=True,
        offload_folder="./offload_base",  # Folder for disk offloading, since we run on free CPU hardware rather than GPU
    )
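    # Optional sanity check (get_memory_footprint() is available on transformers'
    # PreTrainedModel); uncomment to log the quantized base model's size:
    # print(f"Base model memory footprint: {base_model.get_memory_footprint() / 1e9:.2f} GB")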
    # === CRITICAL FIX: Prevent the 'multiple adapters' warning/crash ===
    # If the base model repo contains an old PEFT config, it is loaded into the
    # model's internal _peft_config attribute. We must delete this before
    # loading the new adapter to prevent conflicts.
    if hasattr(base_model, "_peft_config"):
        print("Cleaning up potentially conflicting _peft_config from the base model.")
        del base_model._peft_config
    # ==================================================================
    # Load the PEFT (LoRA) adapter weights on top of the base model
    model = PeftModel.from_pretrained(
        base_model,
        ADAPTER_MODEL,
        offload_folder="./offload_peft",  # Disk offload folder for the adapter (PEFT expects 'offload_folder', not 'offload_dir')
    )
    # Get the tokenizer. Llama 2 tokenizers ship without a pad token, so the EOS
    # token is reused for padding.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Create the Hugging Face pipeline for easy text generation
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # device=0 if device == "cuda" else -1  # Not needed: placement is handled by device_map="auto"
    )

    print("Model and Tokenizer loaded successfully.")
    return generator

# Load the model outside the prediction function so it runs only once
generator = load_model()

# --- Prediction Function ---
def generate_response(prompt, max_new_tokens=256, temperature=0.7):
    """Generates text using the fine-tuned model."""
    # Llama models often work best with a system prompt structure
    system_prompt = "You are a specialized medical assistant. Provide concise and accurate information."
    formatted_prompt = f"### System:\n{system_prompt}\n\n### User:\n{prompt}\n\n### Assistant:\n"

    try:
        # Run the generation pipeline
        result = generator(
            formatted_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            return_full_text=False  # Only return the generated part of the response
        )
        # Extract the text and clean up any surrounding whitespace
        generated_text = result[0]['generated_text'].strip()
        return generated_text
    except Exception as e:
        return f"An error occurred during generation: {e}"
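
# Quick smoke test for local runs (hypothetical example, not used by the Space UI):
# print(generate_response("What are the symptoms of type 2 diabetes?"))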
# --- Gradio Interface Setup ---
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(lines=4, label="Medical Query (e.g., 'What are the symptoms of type 2 diabetes?')", placeholder="Enter your medical question..."),
        gr.Slider(minimum=32, maximum=1024, step=32, value=256, label="Max Response Length", info="Controls the length of the generated answer."),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Creativity (Temperature)", info="Higher temperature means more creative/risky answers.")
    ],
    outputs=gr.Textbox(lines=10, label="Fine-Tuned Medical Assistant Response"),
    title="⚕️ Medical Llama 2 Fine-Tune Demo By Denis Mbugua (dynamodenis254)",
    description="This demo uses a Llama 2 model fine-tuned on medical data. Enter a query and observe the specialized response.",
    theme="soft"
)
# Launch is handled automatically by Hugging Face Spaces
if __name__ == "__main__":
    iface.launch()