4rduino committed on
Commit
af5c76d
·
verified ·
1 Parent(s): 575e98e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
5
+ from peft import PeftModel
6
+
7
+ # --- Configuration ---
8
+ BASE_MODEL_ID = "Qwen/Qwen3-0.6B"
9
+ ADAPTER_MODEL_ID = "4rduino/Qwen3-0.6B-dieter-sft"
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
# --- Model Loading ---

def load_models():
    """
    Load the base model, tokenizer, and LoRA-adapted fine-tuned model.

    Populates the module-level globals ``base_model``, ``finetuned_model``
    and ``tokenizer``. Called once at import time (below) so the models are
    ready before the Gradio app starts serving requests.
    """
    global base_model, finetuned_model, tokenizer

    print("Loading base model and tokenizer...")

    # 4-bit NF4 quantization keeps the memory footprint small while fp16
    # compute preserves generation quality.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)

    print("Base model loaded.")

    print("Loading and applying LoRA adapter...")
    # NOTE: PeftModel injects the LoRA layers into base_model in place.
    # Keeping it as a PeftModel (instead of merge_and_unload) avoids a second
    # full copy of the weights and lets callers toggle the adapter off to
    # recover true base-model behavior (see generate_text).
    finetuned_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)

    print("Models are ready!")


# BUG FIX: the original used ``@gr.on(startup=True)``, which is not a valid
# Gradio API (``gr.on`` attaches component event listeners and has no
# ``startup`` parameter), so the app crashed at import and the models were
# never loaded. Load them eagerly at module import instead.
load_models()
51
+
52
def generate_text(prompt, temperature, max_new_tokens):
    """
    Generate completions for *prompt* from both the base and fine-tuned model.

    Args:
        prompt: The user prompt to complete.
        temperature: Sampling temperature; values <= 0 are clamped to 0.01
            because sampling requires a strictly positive temperature.
        max_new_tokens: Maximum number of new tokens to generate.

    Returns:
        Tuple ``(base_response, finetuned_response)`` containing only the
        newly generated text (the prompt is stripped).
    """
    if temperature <= 0:
        temperature = 0.01

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Number of prompt tokens — used to strip the prompt from the output at
    # the token level. Slicing the decoded string by len(prompt) (as the
    # original did) is fragile: detokenization can normalize whitespace so
    # the decoded text need not start with the literal prompt.
    prompt_len = inputs["input_ids"].shape[1]

    generate_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    # --- Generate from Base Model ---
    # BUG FIX: PeftModel.from_pretrained injected the LoRA layers into
    # base_model in place, so calling base_model.generate() directly would
    # also use the adapter — the "base" output was contaminated. Temporarily
    # disable the adapter to obtain true base-model output.
    print("Generating from base model...")
    with torch.no_grad(), finetuned_model.disable_adapter():
        base_outputs = finetuned_model.generate(**inputs, **generate_kwargs)
    base_response = tokenizer.decode(
        base_outputs[0][prompt_len:], skip_special_tokens=True
    )

    # --- Generate from Fine-tuned Model ---
    print("Generating from fine-tuned model...")
    with torch.no_grad():
        finetuned_outputs = finetuned_model.generate(**inputs, **generate_kwargs)
    finetuned_response = tokenizer.decode(
        finetuned_outputs[0][prompt_len:], skip_special_tokens=True
    )

    print("Generation complete.")

    return base_response, finetuned_response
86
# --- Gradio Interface ---

# Light cosmetic tweaks: centered title, rounded boxes, green primary button.
css = """
h1 { text-align: center; }
.gr-box { border-radius: 10px !important; }
.gr-button { background-color: #4CAF50 !important; color: white !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# 🤖 Model Comparison: Base vs. Fine-tuned 'Dieter'")
    gr.Markdown(
        "Enter a prompt to see how the fine-tuned 'Dieter' model compares to the original Qwen-0.6B base model. "
        "The 'Dieter' model was fine-tuned for a creative director persona."
    )

    with gr.Row():
        # Left column: prompt entry plus generation settings.
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Your Prompt",
                placeholder="e.g., Write a tagline for a new brand of sparkling water.",
                lines=4,
            )
            with gr.Accordion("Generation Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=50, maximum=512, value=150, step=1, label="Max New Tokens"
                )
            btn = gr.Button("Generate", variant="primary")

        # Right column: side-by-side outputs of both models.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Side-by-Side"):
                    with gr.Row():
                        out_base = gr.Textbox(
                            label="Base Model Output", lines=12, interactive=False
                        )
                        out_finetuned = gr.Textbox(
                            label="Fine-tuned 'Dieter' Output", lines=12, interactive=False
                        )

    btn.click(
        fn=generate_text,
        inputs=[prompt, temperature, max_new_tokens],
        outputs=[out_base, out_finetuned],
        api_name="compare",
    )

    gr.Examples(
        [
            ["Write a creative brief for a new, eco-friendly sneaker brand."],
            ["Generate three concepts for a new fragrance campaign targeting Gen Z."],
            ["What's a bold, unexpected idea for a car commercial?"],
            ["Give me some feedback on this headline: 'The Future of Coffee is Here.'"],
        ],
        inputs=[prompt],
    )

# Launch only when run as a script so the module can be imported (e.g. by
# Hugging Face Spaces tooling or tests) without starting a server.
if __name__ == "__main__":
    demo.launch()