|
|
| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from peft import PeftModel |
|
|
| |
BASE_MODEL_ID = "Qwen/Qwen3-0.6B"  # Hugging Face hub ID of the original base model
ADAPTER_MODEL_ID = "4rduino/Qwen3-0.6B-dieter-sft"  # LoRA adapter fine-tuned for the 'Dieter' persona
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # device the tokenized inputs are moved to
|
|
| |
|
|
def load_models():
    """
    Load the 4-bit quantized base model, its tokenizer, and the LoRA adapter.

    Populates the module-level globals ``base_model``, ``finetuned_model``
    and ``tokenizer`` that ``generate_text`` reads. Intended to run exactly
    once, at module import time (see the call below the definition).
    """
    global base_model, finetuned_model, tokenizer

    print("Loading base model and tokenizer...")

    # NF4 4-bit quantization keeps the 0.6B model's memory footprint small
    # enough for modest GPUs; compute happens in float16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    print("Base model loaded.")

    print("Loading and applying LoRA adapter...")
    # NOTE: PeftModel.from_pretrained injects the LoRA layers into base_model
    # IN PLACE — base_model and finetuned_model share weights afterwards and
    # differ only in whether the adapter is enabled. generate_text() relies on
    # disable_adapter() to recover true base-model behavior.
    finetuned_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)

    print("Models are ready!")


# BUG FIX: gr.on() has no `startup` parameter — the original
# @gr.on(startup=True) decorator would raise a TypeError at import time.
# Gradio has no startup decorator; load the models once at module import.
load_models()
|
|
|
|
def generate_text(prompt, temperature, max_new_tokens):
    """
    Generate completions for *prompt* from both the base and fine-tuned models.

    Args:
        prompt: The user's text prompt.
        temperature: Sampling temperature; values <= 0 are clamped to 0.01
            because sampling with temperature 0 is undefined.
        max_new_tokens: Maximum number of tokens to generate per model.

    Returns:
        A ``(base_response, finetuned_response)`` tuple containing only the
        newly generated text (the prompt is not echoed back).
    """
    if temperature <= 0:
        temperature = 0.01

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    generate_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # Inference only — no gradients needed; saves memory and time.
    with torch.no_grad():
        print("Generating from base model...")
        # BUG FIX: PeftModel injects the LoRA layers into the shared base model
        # in place, so calling base_model.generate() directly would also apply
        # the adapter. Temporarily disable it to get a true base-model baseline.
        with finetuned_model.disable_adapter():
            base_outputs = base_model.generate(**inputs, **generate_kwargs)

        print("Generating from fine-tuned model...")
        finetuned_outputs = finetuned_model.generate(**inputs, **generate_kwargs)

    print("Generation complete.")

    # BUG FIX: decode only the newly generated token slice. Slicing the decoded
    # string with len(prompt) is fragile because tokenization/detokenization may
    # normalize whitespace or special characters in the prompt.
    base_response = tokenizer.decode(
        base_outputs[0][prompt_len:], skip_special_tokens=True
    )
    finetuned_response = tokenizer.decode(
        finetuned_outputs[0][prompt_len:], skip_special_tokens=True
    )

    return base_response, finetuned_response
|
|
| |
|
|
# Minimal CSS overrides for the Gradio theme: centered page title, rounded
# component boxes, and a green primary button.
css = """
h1 { text-align: center; }
.gr-box { border-radius: 10px !important; }
.gr-button { background-color: #4CAF50 !important; color: white !important; }
"""
|
|
# UI layout: a prompt column (with generation settings) on the left and the
# two model outputs side by side on the right.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# 🤖 Model Comparison: Base vs. Fine-tuned 'Dieter'")
    gr.Markdown(
        "Enter a prompt to see how the fine-tuned 'Dieter' model compares to the original Qwen-0.6B base model. "
        "The 'Dieter' model was fine-tuned for a creative director persona."
    )

    with gr.Row():
        # Left column: prompt input, sampling controls, and the trigger button.
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Your Prompt",
                placeholder="e.g., Write a tagline for a new brand of sparkling water.",
                lines=4,
            )
            with gr.Accordion("Generation Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=50, maximum=512, value=150, step=1, label="Max New Tokens"
                )
            btn = gr.Button("Generate", variant="primary")

        # Right column: read-only outputs for both models, shown side by side.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Side-by-Side"):
                    with gr.Row():
                        out_base = gr.Textbox(label="Base Model Output", lines=12, interactive=False)
                        out_finetuned = gr.Textbox(label="Fine-tuned 'Dieter' Output", lines=12, interactive=False)

    # Wire the button to the comparison function; api_name exposes it as an
    # API endpoint named "compare".
    btn.click(
        fn=generate_text,
        inputs=[prompt, temperature, max_new_tokens],
        outputs=[out_base, out_finetuned],
        api_name="compare"
    )

    # Clickable example prompts that pre-fill the prompt textbox.
    gr.Examples(
        [
            ["Write a creative brief for a new, eco-friendly sneaker brand."],
            ["Generate three concepts for a new fragrance campaign targeting Gen Z."],
            ["What's a bold, unexpected idea for a car commercial?"],
            ["Give me some feedback on this headline: 'The Future of Coffee is Here.'"],
        ],
        inputs=[prompt],
    )
|
|
# Start the Gradio server (blocks until the app is shut down).
demo.launch()
|
|