# Scraped page header (not part of the program): 4rduino's picture
# Upload app.py with huggingface_hub
# f21c2f4 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# --- Configuration ---
# Hugging Face Hub id of the unmodified base checkpoint.
BASE_MODEL_ID = "Qwen/Qwen3-0.6B"
# Hub id of the LoRA adapter that is applied on top of the base checkpoint.
ADAPTER_MODEL_ID = "4rduino/Qwen3-0.6B-dieter-sft"
# Prefer the GPU when torch reports one; otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# --- Model Loading ---
def load_models():
    """
    Load the tokenizer, the base model and the LoRA adapter.

    Populates the module-level globals ``base_model``, ``finetuned_model``
    and ``tokenizer``. The function is invoked explicitly right after its
    definition so the models are loaded once, when the module is first run.
    (``gr.on`` registers component event listeners and takes no ``startup``
    argument, so it cannot serve as a start-up hook.)
    """
    global base_model, finetuned_model, tokenizer

    print("Loading base model and tokenizer...")
    # Use 4-bit quantization for memory efficiency
    # NOTE(review): bitsandbytes 4-bit loading has historically required a
    # CUDA GPU -- confirm the target hardware if DEVICE can resolve to "cpu".
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    print("Base model loaded.")

    print("Loading and applying LoRA adapter...")
    # Apply the adapter to the base model to get the fine-tuned model.
    # The adapter is NOT merged: the result stays a PeftModel wrapped around
    # ``base_model``, which avoids the memory cost of a second, merged copy.
    # NOTE(review): PEFT injects the adapter layers into the wrapped model in
    # place, so ``base_model`` shares its modules with ``finetuned_model`` --
    # confirm that callers needing pure base-model output disable the adapter.
    finetuned_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    print("Models are ready!")


# Load everything once, before the UI below is built and launched.
load_models()
def generate_text(prompt, temperature, max_new_tokens):
    """
    Generate a completion for ``prompt`` from both the base and the fine-tuned model.

    Args:
        prompt: The user prompt. An empty or whitespace-only prompt yields two
            empty strings without running either model.
        temperature: Sampling temperature. Values <= 0 are raised to 0.01
            because sampling is always enabled.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).

    Returns:
        A ``(base_response, finetuned_response)`` tuple containing only the
        newly generated text, without the prompt.
    """
    if not prompt or not prompt.strip():
        # Nothing to condition on; skip the model calls entirely.
        return "", ""

    if temperature <= 0:
        temperature = 0.01

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Number of prompt tokens; generated sequences start with these ids.
    prompt_length = inputs["input_ids"].shape[1]
    generate_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # --- Generate from Base Model ---
    print("Generating from base model...")
    # The LoRA layers live inside the same modules that ``base_model`` uses,
    # so the adapter must be switched off to obtain genuine base-model output.
    with finetuned_model.disable_adapter():
        base_outputs = base_model.generate(**inputs, **generate_kwargs)

    # --- Generate from Fine-tuned Model ---
    print("Generating from fine-tuned model...")
    finetuned_outputs = finetuned_model.generate(**inputs, **generate_kwargs)
    print("Generation complete.")

    # Return only the newly generated part: drop the prompt *tokens* before
    # decoding, since the decoded text is not guaranteed to reproduce the
    # prompt string character-for-character.
    base_response = tokenizer.decode(
        base_outputs[0][prompt_length:], skip_special_tokens=True
    )
    finetuned_response = tokenizer.decode(
        finetuned_outputs[0][prompt_length:], skip_special_tokens=True
    )
    return base_response, finetuned_response
# --- Gradio Interface ---
# Custom styling: centered title, rounded boxes, green buttons.
css = """
h1 { text-align: center; }
.gr-box { border-radius: 10px !important; }
.gr-button { background-color: #4CAF50 !important; color: white !important; }
"""

demo = gr.Blocks(css=css, theme=gr.themes.Soft())

with demo:
    # Page title and short description.
    gr.Markdown("# 🤖 Model Comparison: Base vs. Fine-tuned 'Dieter'")
    gr.Markdown(
        "Enter a prompt to see how the fine-tuned 'Dieter' model compares to the original Qwen-0.6B base model. "
        "The 'Dieter' model was fine-tuned for a creative director persona."
    )

    with gr.Row():
        # Left column: prompt input, sampling controls and the trigger button.
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                lines=4,
                label="Your Prompt",
                placeholder="e.g., Write a tagline for a new brand of sparkling water.",
            )
            with gr.Accordion("Generation Settings", open=False):
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=0.7,
                )
                max_new_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=50,
                    maximum=512,
                    step=1,
                    value=150,
                )
            btn = gr.Button("Generate", variant="primary")

        # Right column: the two model outputs next to each other.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Side-by-Side"):
                    with gr.Row():
                        out_base = gr.Textbox(
                            label="Base Model Output",
                            lines=12,
                            interactive=False,
                        )
                        out_finetuned = gr.Textbox(
                            label="Fine-tuned 'Dieter' Output",
                            lines=12,
                            interactive=False,
                        )

    # Clicking the button runs both models and fills the two output boxes.
    btn.click(
        generate_text,
        inputs=[prompt, temperature, max_new_tokens],
        outputs=[out_base, out_finetuned],
        api_name="compare",
    )

    # Example prompts that fill the prompt textbox.
    gr.Examples(
        examples=[
            ["Write a creative brief for a new, eco-friendly sneaker brand."],
            ["Generate three concepts for a new fragrance campaign targeting Gen Z."],
            ["What's a bold, unexpected idea for a car commercial?"],
            ["Give me some feedback on this headline: 'The Future of Coffee is Here.'"],
        ],
        inputs=[prompt],
    )

demo.launch()