"""Gradio demo comparing a base Qwen model against its LoRA fine-tuned 'Dieter' variant."""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen3-0.6B"
ADAPTER_MODEL_ID = "4rduino/Qwen3-0.6B-dieter-sft"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_models():
    """Load the tokenizer, the 4-bit quantized base model, and the LoRA adapter.

    Populates the module-level globals ``base_model``, ``finetuned_model`` and
    ``tokenizer`` that ``generate_text`` reads.
    """
    global base_model, finetuned_model, tokenizer

    print("Loading base model and tokenizer...")
    # 4-bit NF4 quantization keeps the memory footprint small enough for
    # modest GPU/CPU hardware.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    print("Base model loaded.")

    print("Loading and applying LoRA adapter...")
    # Keep the adapter un-merged: a PeftModel shares the quantized base
    # weights, avoiding the extra memory a merged copy would cost.
    finetuned_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    print("Models are ready!")


# BUGFIX: the original decorated load_models with `@gr.on(startup=True)`, but
# `gr.on()` registers component event listeners and has no `startup` parameter,
# so the app crashed at import time and the models were never loaded. Load them
# once at module import instead, before the UI is built.
load_models()


def generate_text(prompt, temperature, max_new_tokens):
    """Generate completions for ``prompt`` from both the base and fine-tuned models.

    Args:
        prompt: User prompt string.
        temperature: Sampling temperature; values <= 0 are clamped to 0.01
            because sampling requires a strictly positive temperature.
        max_new_tokens: Maximum number of new tokens to generate.

    Returns:
        Tuple ``(base_response, finetuned_response)`` containing only the
        newly generated text, with the prompt stripped.
    """
    if temperature <= 0:
        temperature = 0.01

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_len = inputs["input_ids"].shape[-1]

    generate_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # --- Generate from Base Model ---
    print("Generating from base model...")
    base_outputs = base_model.generate(**inputs, **generate_kwargs)
    # BUGFIX: strip the prompt in *token* space before decoding. The original
    # sliced the decoded string with `text[len(prompt):]`, which misaligns
    # whenever the tokenizer round-trip does not reproduce the prompt
    # byte-for-byte (whitespace / special-token normalization).
    base_response = tokenizer.decode(
        base_outputs[0][prompt_len:], skip_special_tokens=True
    )

    # --- Generate from Fine-tuned Model ---
    print("Generating from fine-tuned model...")
    finetuned_outputs = finetuned_model.generate(**inputs, **generate_kwargs)
    finetuned_response = tokenizer.decode(
        finetuned_outputs[0][prompt_len:], skip_special_tokens=True
    )

    print("Generation complete.")
    return base_response, finetuned_response


# --- Gradio Interface ---
css = """
h1 { text-align: center; }
.gr-box { border-radius: 10px !important; }
.gr-button { background-color: #4CAF50 !important; color: white !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# 🤖 Model Comparison: Base vs. Fine-tuned 'Dieter'")
    gr.Markdown(
        "Enter a prompt to see how the fine-tuned 'Dieter' model compares to the original Qwen-0.6B base model. "
        "The 'Dieter' model was fine-tuned for a creative director persona."
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Your Prompt",
                placeholder="e.g., Write a tagline for a new brand of sparkling water.",
                lines=4,
            )
            with gr.Accordion("Generation Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=50, maximum=512, value=150, step=1, label="Max New Tokens"
                )
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Side-by-Side"):
                    with gr.Row():
                        out_base = gr.Textbox(
                            label="Base Model Output", lines=12, interactive=False
                        )
                        out_finetuned = gr.Textbox(
                            label="Fine-tuned 'Dieter' Output",
                            lines=12,
                            interactive=False,
                        )
    btn.click(
        fn=generate_text,
        inputs=[prompt, temperature, max_new_tokens],
        outputs=[out_base, out_finetuned],
        api_name="compare",
    )
    gr.Examples(
        [
            ["Write a creative brief for a new, eco-friendly sneaker brand."],
            ["Generate three concepts for a new fragrance campaign targeting Gen Z."],
            ["What's a bold, unexpected idea for a car commercial?"],
            ["Give me some feedback on this headline: 'The Future of Coffee is Here.'"],
        ],
        inputs=[prompt],
    )

if __name__ == "__main__":
    demo.launch()