jellewas committed
Commit 129e933 · verified · 1 Parent(s): 0173ebf

Upload hf://spaces/jellewas/mistral7b-chat with huggingface_hub

Files changed (1)
  1. hf://spaces/jellewas/mistral7b-chat +190 -0
hf://spaces/jellewas/mistral7b-chat ADDED
@@ -0,0 +1,190 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Model configuration
TRAINED_MODEL = "AnythingSLM/mistral7b-qlora-output"  # Your trained model repository
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables for model and tokenizer
model = None
tokenizer = None
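
# The model and tokenizer are loaded lazily on the first request and cached
# in the globals above, so the Space starts quickly and every chat turn
# after the first reuses the same loaded copy.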

def load_model():
    """Load the trained model directly"""
    global model, tokenizer

    if model is None or tokenizer is None:
        print(f"Loading trained model from {TRAINED_MODEL}...")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(TRAINED_MODEL, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
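
        # A brief note on the 4-bit load below (QLoRA-style inference):
        # "nf4" quantizes the base weights to 4-bit normal float, double
        # quantization also compresses the quantization constants, and
        # float16 is used as the compute dtype for the actual matmuls.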
        # Load the trained model with quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )

        model = AutoModelForCausalLM.from_pretrained(
            TRAINED_MODEL,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

        print("Trained model loaded successfully!")

    return model, tokenizer

def generate_response(message, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response using the fine-tuned model"""
    try:
        model, tokenizer = load_model()

        # Wrap raw messages in Mistral's instruction format. The tokenizer
        # inserts the <s> BOS token itself, so it is left out of the string
        # to avoid a duplicate.
        if not message.startswith("[INST]"):
            prompt = f"[INST] {message} [/INST]"
        else:
            prompt = message

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)

        # Generate response; max_new_tokens bounds the generated text only,
        # so a long prompt cannot consume the whole generation budget
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_length),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return response.strip()

    except Exception as e:
        return f"Error generating response: {str(e)}"

def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(title="Mistral 7B Fine-tuned Chat", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🤖 Mistral 7B Fine-tuned Chat")
        gr.Markdown("Chat with a fine-tuned Mistral 7B model using QLoRA adapters.")

        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    height=400,
                    show_label=False,
                    container=True,
                )
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    lines=2,
                )

            with gr.Column(scale=1):
                with gr.Accordion("Parameters", open=False):
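                    # Sampling controls: temperature rescales the logits
                    # (higher = more varied output); top_p keeps only the
                    # smallest token set whose probabilities sum to top_p.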
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=512,
                        step=64,
                        label="Max New Tokens",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top P",
                    )

                clear_btn = gr.Button("Clear Chat")
                submit_btn = gr.Button("Send", variant="primary")
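
        # The wiring below follows Gradio's two-step chat pattern:
        # user_message appends the user turn and clears the textbox right
        # away (queue=False keeps it snappy), then .then() runs bot_response
        # to fill in the model's reply for that turn.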
        def user_message(message, history):
            if not message.strip():
                return "", history

            history = history + [[message, None]]
            return "", history

        def bot_response(history, max_length, temperature, top_p):
            if not history or history[-1][1] is not None:
                return history

            user_msg = history[-1][0]
            bot_msg = generate_response(user_msg, max_length, temperature, top_p)

            history[-1][1] = bot_msg
            return history

        def clear_chat():
            return []

        msg.submit(
            user_message,
            [msg, chatbot],
            [msg, chatbot],
            queue=False
        ).then(
            bot_response,
            [chatbot, max_length, temperature, top_p],
            chatbot
        )

        submit_btn.click(
            user_message,
            [msg, chatbot],
            [msg, chatbot],
            queue=False
        ).then(
            bot_response,
            [chatbot, max_length, temperature, top_p],
            chatbot
        )

        clear_btn.click(clear_chat, outputs=chatbot)

        gr.Markdown("""
### About
This Space demonstrates a fine-tuned Mistral 7B model using QLoRA (4-bit quantization + LoRA adapters).

**Features:**
- 4-bit quantized base model for memory efficiency
- LoRA adapters for task-specific fine-tuning
- Adjustable generation parameters
- Real-time chat interface

**Model:** Mistral 7B Instruct v0.3 base + custom fine-tuning
""")

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
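
# Note: running this Space assumes a requirements.txt alongside this file
# providing gradio, torch, transformers, accelerate, and bitsandbytes (the
# usual packages for 4-bit loading; exact pins are not part of this commit).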