# NOTE: This file was recovered from a Hugging Face Space page (the "Spaces:" /
# "Runtime error" banner lines were page chrome, not code). The markdown table
# pipes wrapping each line have been stripped to restore valid Python.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen3-0.6B"
ADAPTER_MODEL_ID = "4rduino/Qwen3-0.6B-dieter-sft"
# Prefer GPU when available; used to place tokenized inputs in generate_text.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- Model Loading ---
def load_models():
    """
    Load the base model, tokenizer, and LoRA-adapted model into module globals.

    Must be called exactly once before ``generate_text`` is used (e.g. right
    before ``demo.launch()``); otherwise the globals it sets are undefined.

    Globals set:
        base_model: 4-bit quantized Qwen3-0.6B causal LM.
        finetuned_model: PeftModel wrapping ``base_model`` with the LoRA adapter.
        tokenizer: Tokenizer matching the base model.
    """
    global base_model, finetuned_model, tokenizer

    print("Loading base model and tokenizer...")
    # 4-bit NF4 quantization keeps the model within small-GPU memory budgets.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    print("Base model loaded.")

    print("Loading and applying LoRA adapter...")
    # Keep the adapter un-merged: the PeftModel shares the quantized base
    # weights, avoiding a second full copy of the model in memory. NOTE:
    # this injects LoRA layers into the base model's modules in-place, so
    # callers must disable the adapter to get true base-model behavior.
    finetuned_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    print("Models are ready!")
def generate_text(prompt, temperature, max_new_tokens):
    """
    Generate a continuation of ``prompt`` from both the base and fine-tuned model.

    Args:
        prompt: User prompt string.
        temperature: Sampling temperature; values <= 0 are clamped to 0.01
            because HF ``generate`` rejects non-positive temperatures when
            ``do_sample=True``.
        max_new_tokens: Maximum number of new tokens to generate (coerced to int).

    Returns:
        Tuple ``(base_response, finetuned_response)`` containing only the newly
        generated text from each model (prompt tokens stripped).
    """
    if temperature <= 0:
        temperature = 0.01

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Number of prompt tokens — used to strip the prompt from the outputs.
    # Slicing the decoded string by len(prompt) is fragile because the
    # tokenizer may normalize whitespace; slicing token ids is exact.
    prompt_len = inputs["input_ids"].shape[1]

    generate_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # --- Generate from Base Model ---
    print("Generating from base model...")
    # BUG FIX: PeftModel.from_pretrained injects the LoRA layers into the
    # shared base model in-place, so a plain base_model.generate would also
    # apply the adapter. Disable it to obtain a true base-model baseline.
    with finetuned_model.disable_adapter():
        base_outputs = base_model.generate(**inputs, **generate_kwargs)
    base_response = tokenizer.decode(
        base_outputs[0][prompt_len:], skip_special_tokens=True
    )

    # --- Generate from Fine-tuned Model ---
    print("Generating from fine-tuned model...")
    finetuned_outputs = finetuned_model.generate(**inputs, **generate_kwargs)
    finetuned_response = tokenizer.decode(
        finetuned_outputs[0][prompt_len:], skip_special_tokens=True
    )

    print("Generation complete.")
    return base_response, finetuned_response
# --- Gradio Interface ---
css = """
h1 { text-align: center; }
.gr-box { border-radius: 10px !important; }
.gr-button { background-color: #4CAF50 !important; color: white !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# 🤖 Model Comparison: Base vs. Fine-tuned 'Dieter'")
    gr.Markdown(
        "Enter a prompt to see how the fine-tuned 'Dieter' model compares to the original Qwen-0.6B base model. "
        "The 'Dieter' model was fine-tuned for a creative director persona."
    )

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Your Prompt",
                placeholder="e.g., Write a tagline for a new brand of sparkling water.",
                lines=4,
            )
            with gr.Accordion("Generation Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=50, maximum=512, value=150, step=1, label="Max New Tokens"
                )
            btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Side-by-Side"):
                    with gr.Row():
                        out_base = gr.Textbox(
                            label="Base Model Output", lines=12, interactive=False
                        )
                        out_finetuned = gr.Textbox(
                            label="Fine-tuned 'Dieter' Output", lines=12, interactive=False
                        )

    btn.click(
        fn=generate_text,
        inputs=[prompt, temperature, max_new_tokens],
        outputs=[out_base, out_finetuned],
        api_name="compare",
    )

    gr.Examples(
        [
            ["Write a creative brief for a new, eco-friendly sneaker brand."],
            ["Generate three concepts for a new fragrance campaign targeting Gen Z."],
            ["What's a bold, unexpected idea for a car commercial?"],
            ["Give me some feedback on this headline: 'The Future of Coffee is Here.'"],
        ],
        inputs=[prompt],
    )

# BUG FIX: load_models() was defined but never invoked, so the model globals
# (tokenizer, base_model, finetuned_model) were undefined at generation time —
# the likely cause of the Space's "Runtime error". Load before serving.
load_models()
demo.launch()