Krish-05 committed on
Commit
f27d99c
verified
1 Parent(s): 0eb3802

Update README.md

Files changed (1)
  1. README.md +314 -1
README.md CHANGED
@@ -9,4 +9,317 @@ pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

```python
import json
import os
from datasets import Dataset
import torch
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
import jsonlines  # Recommended for reading .jsonl files

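# The pipeline below assumes each line of the input .jsonl is a JSON object with
# "prompt" and "response" keys, for example (illustrative values only):
#   {"prompt": "question about cancelling order {{Order Number}}",
#    "response": "I understand you need help cancelling order {{Order Number}} ..."}
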
def remove_duplicates_jsonl(input_file, output_file):
    """
    Remove duplicate entries from a JSONL file based on prompt and response.
    Preserves the first occurrence of each unique entry.

    Args:
        input_file (str): Path to the input JSONL file
        output_file (str): Path to the output JSONL file where deduplicated data will be written
    """

    # Store unique entries using prompt+response as the key
    unique_entries = {}
    duplicate_count = 0
    line_count = 0

    print(f"Processing file: {input_file}")

    try:
        # Read the input file and track unique entries
        with open(input_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    # Skip empty lines
                    if not line.strip():
                        continue

                    # Parse the JSON line
                    data = json.loads(line.strip())
                    line_count += 1

                    # Create a unique key from prompt + response
                    unique_key = f"{data.get('prompt', '')}|{data.get('response', '')}"

                    # Track only the first occurrence of each unique entry
                    if unique_key not in unique_entries:
                        unique_entries[unique_key] = data
                    else:
                        duplicate_count += 1

                except json.JSONDecodeError as e:
                    print(f"Error parsing JSON on line {line_num}: {str(e)}")
                    continue

        # Write the unique entries to the output file
        with open(output_file, 'w', encoding='utf-8') as f:
            for data in unique_entries.values():
                json_str = json.dumps(data, ensure_ascii=False)
                f.write(json_str + '\n')

        # Print a summary
        print("\nDeduplication Summary:")
        print(f"Total lines processed: {line_count}")
        print(f"Duplicate entries removed: {duplicate_count}")
        print(f"Unique entries remaining: {len(unique_entries)}")
        print(f"Output written to: {output_file}")

    except Exception as e:
        print(f"Error processing file: {str(e)}")
        return

if __name__ == "__main__":
    input_file = "prompt_response_pairs.jsonl"
    output_file = "prompt_response_pairs_deduped.jsonl"

    remove_duplicates_jsonl(input_file, output_file)

# --- 1. Load data from the .jsonl file ---
file_path = "prompt_response_pairs_deduped.jsonl"  # Deduplicated output from the step above

# Read the .jsonl file
data = []
with jsonlines.open(file_path, 'r') as reader:
    for obj in reader:
        data.append(obj)

print(f"Loaded {len(data)} entries from {file_path}")
if len(data) > 0:  # Check that the data is not empty before printing
    print("First entry:", data[0])

# Quick GPU check
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# --- 2. Load the base model and tokenizer with Unsloth ---
model_name = "unsloth/llama-3-8b-bnb-4bit"

max_seq_length = 2048  # Choose sequence length
dtype = None  # None = auto-detect

# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)

# --- 3. Format the dataset with the Llama 3 chat template ---
def format_prompt(example):
    user_prompt = example.get('prompt', '')
    assistant_response = example.get('response', '')
    # Llama 3 chat template:
    # each training example is formatted as a full user/assistant conversation turn.
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{assistant_response}<|eot_id|>"
    )

formatted_data = [format_prompt(item) for item in data]  # Uses 'data' loaded from the .jsonl file
dataset = Dataset.from_dict({"text": formatted_data})

# --- 4. Add LoRA adapters and train ---
model = FastLanguageModel.get_peft_model(
    model,
    r=64,  # LoRA rank - higher = more capacity, more memory
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=128,  # LoRA scaling factor (usually 2x the rank)
    lora_dropout=0,  # Supports any value, but 0 is optimized
    bias="none",  # Supports any value, but "none" is optimized
    use_gradient_checkpointing="unsloth",  # Unsloth's optimized version
    random_state=3407,
    use_rslora=False,  # Rank-stabilized LoRA
    loftq_config=None,  # LoftQ
)

# Training arguments optimized for Unsloth
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # Effective batch size = 8
        warmup_steps=10,
        num_train_epochs=3,
        learning_rate=2e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=25,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        save_strategy="epoch",
        save_total_limit=2,
        dataloader_pin_memory=False,
    ),
)

# Train the model
trainer_stats = trainer.train()
print("Training complete!")

# Option A: Save just the LoRA adapters (most common for continued fine-tuning).
# This creates a folder with adapter_model.safetensors and tokenizer files.
lora_model_dir = "./lora_adapters_saved"
print(f"\nSaving LoRA adapters to: {lora_model_dir}")
model.save_pretrained(lora_model_dir)
tokenizer.save_pretrained(lora_model_dir)
print("LoRA adapters and tokenizer saved!")


# --- 5. Test the fine-tuned model ---
FastLanguageModel.for_inference(model)  # Enable Unsloth's native 2x faster inference

# Test prompt, formatted with the Llama 3 chat template.
# At inference time the prompt contains only the user's side of the conversation.
messages = [
    {"role": "user", "content": "i have a question about cancelling order 12345"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Adds the assistant header so the model generates its turn
    return_tensors="pt",
).to("cuda")

# Generate response
print("\nGenerating response with fine-tuned model...")
outputs = model.generate(
    input_ids=inputs,
    max_new_tokens=256,
    use_cache=True,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
)

# Decode and print
response = tokenizer.decode(outputs[0], skip_special_tokens=False)  # Keep special tokens for inspection
print(response)

# To get just the generated assistant part, parse the decoded string.
try:
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    assistant_start = response.find(assistant_header)
    if assistant_start != -1:
        generated_text = response[assistant_start + len(assistant_header):].strip()
        # Remove any trailing <|eot_id|> or other special tokens if present
        generated_text = generated_text.replace("<|eot_id|>", "").strip()
        print("\nExtracted Generated Output:")
        print(generated_text)
except Exception as e:
    print(f"Could not extract generated text: {e}")


# --- 6. Save the model in GGUF format ---
# This saves the model with the LoRA adapters merged into the base model and quantized.
# `model.save_pretrained_gguf` handles the merging and quantization internally
# before writing the GGUF file.
gguf_output_dir = "gguf_model"
os.makedirs(gguf_output_dir, exist_ok=True)  # Ensure the directory exists
print(f"\nSaving model to GGUF format in: {gguf_output_dir}")
model.save_pretrained_gguf(gguf_output_dir, tokenizer, quantization_method="q4_k_m")
print("Model saved in GGUF format!")
```

## Fine-Tuning Details

- **Base Model:** `unsloth/llama-3-8b-bnb-4bit` (a 4-bit quantized Llama 3 variant optimized for efficient fine-tuning).
- **Method:** LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning.
- **Tokenizer:** Uses the Llama 3 chat template to format prompts and responses for training.
- **Training Arguments:** `adamw_8bit` optimizer, linear learning-rate scheduler, and per-epoch checkpoints (the two most recent are kept).
- **Output Format:** The fine-tuned LoRA adapters are saved, and the merged model is also exported in GGUF format for local deployment with Ollama (a registration sketch follows below).

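To serve the GGUF export with Ollama, it can be registered under a model tag. Here is a minimal sketch, assuming Ollama's standard `Modelfile` / `ollama create` flow and whatever `.gguf` file `save_pretrained_gguf` wrote into the output directory; the tag matches the one used by the backend later in this README.

```python
# Illustrative sketch: register the exported GGUF with Ollama for local serving.
import subprocess
from pathlib import Path

# Find the quantized GGUF file written by save_pretrained_gguf (the file name may vary).
gguf_file = next(Path("gguf_model").glob("*.gguf"))

# Minimal Modelfile: just point Ollama at the local GGUF weights.
Path("Modelfile").write_text(f"FROM {gguf_file}\n")

# Register the model under the tag used by the backend.
subprocess.run(
    ["ollama", "create", "krishna_choudhary/tinyllama", "-f", "Modelfile"],
    check=True,
)
```
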
## Architecture and Workflow

The chatbot is a full-stack application fronted by Nginx, with a FastAPI backend providing the application logic, a React.js frontend providing the user interface, and Ollama serving the fine-tuned LLM locally inside the container.

*Conceptual architecture, similar to a Hugging Face Space deployment.*

### Components

- **Frontend (React.js):** Provides the user interface for interacting with the chatbot, including a text input area and a voice input feature.
- **Nginx:** Acts as a reverse proxy, routing requests from the frontend to the FastAPI backend, and serves the static frontend files.
- **FastAPI (Python):** exposes two endpoints (a minimal sketch of both follows this list):
  - **LLM Endpoint (`/api/ask`):** Receives text prompts from the frontend, formats them, sends them to the Ollama-served LLM, and streams the generated responses back to the frontend.
  - **Audio Transcription Endpoint (`/api/transcribe-audio`):** Receives audio blobs from the frontend, transcribes them with a Whisper model loaded inside FastAPI, and returns the transcribed text.
- **Ollama:** A local large language model server. It runs the fine-tuned `krishna_choudhary/tinyllama` model and exposes it for inference via an API that FastAPI calls.
- **Whisper Model (`dimavz/whisper-tiny`):** Integrated within the FastAPI application for speech-to-text.

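A minimal sketch of how these two endpoints could be wired up, not the project's exact backend code: the endpoint paths and model tag come from this README, while the `openai-whisper` loading call (standing in for `dimavz/whisper-tiny`), the temporary-file handling, and Ollama's default port 11434 are assumptions.

```python
# Illustrative backend sketch only.
import json
import tempfile

import requests
import whisper  # assumption: openai-whisper used for transcription inside FastAPI
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse

app = FastAPI()
whisper_model = whisper.load_model("tiny")  # loaded once at startup

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default port
MODEL_NAME = "krishna_choudhary/tinyllama"


def stream_from_ollama(prompt: str):
    """Forward the user prompt to Ollama and yield response chunks as they arrive."""
    payload = {"model": MODEL_NAME, "prompt": prompt, "stream": True}
    with requests.post(OLLAMA_URL, json=payload, stream=True) as r:
        for line in r.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)  # Ollama streams newline-delimited JSON objects
            yield chunk.get("response", "")


@app.post("/api/ask")
def ask(body: dict):
    # The frontend sends the (typed or transcribed) user prompt as JSON.
    prompt = body.get("prompt", "")
    return StreamingResponse(stream_from_ollama(prompt), media_type="text/plain")


@app.post("/api/transcribe-audio")
async def transcribe_audio(file: UploadFile = File(...)):
    # Write the uploaded audio blob to a temporary file so Whisper can read it.
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name
    result = whisper_model.transcribe(tmp_path)
    return {"text": result["text"]}
```
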
### Workflow

**User Input:** The user types a message or records a voice message in the React frontend.

**Voice to Text (if applicable):**
- If a voice message is recorded, the frontend sends the audio blob to FastAPI's `/api/transcribe-audio` endpoint.
- FastAPI uses the `dimavz/whisper-tiny` model to transcribe the audio into text.
- The transcribed text is sent back to the frontend.

**Text Prompt to LLM:** Whether the input was typed or transcribed, the frontend sends the text prompt to FastAPI's `/api/ask` endpoint.

**LLM Inference (FastAPI & Ollama):**
- FastAPI receives the user prompt.
- It formats the prompt to match the Llama 3 chat template expected by the `krishna_choudhary/tinyllama` model.
- FastAPI sends the formatted prompt to the locally running Ollama server.
- Ollama processes the prompt with the fine-tuned `krishna_choudhary/tinyllama` model.
- The generated response is streamed back to FastAPI (a quick smoke test of this step is sketched below).

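To exercise this step in isolation, the Ollama server can be queried directly; the snippet below is a rough illustration that assumes Ollama's default port and the model tag used above.

```python
# Illustrative smoke test of the Ollama-served model.
import requests

prompt = "i have a question about cancelling order 12345"
resp = requests.post(
    "http://localhost:11434/api/generate",
    json={"model": "krishna_choudhary/tinyllama", "prompt": prompt, "stream": False},
    timeout=120,
)
print(resp.json()["response"])
```
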
**Response Handling and Token Replacement:**
- FastAPI receives the streaming response from Ollama.
- The fine-tuned model is designed to generate responses containing placeholders (e.g., `{{Order Number}}`, `{{Online Company Portal Info}}`).
- FastAPI identifies these placeholders in the LLM's raw output.
- **Token Replacement:** In production, FastAPI would dynamically replace these placeholders with real-time data from a database or other internal systems (for example, a customer's actual order number from a CRM, or a real company portal URL). In this project the model emits responses with the placeholders left in, exactly as seen in the training data, which demonstrates its ability to recognize and produce structured responses; the dynamic replacement logic would live at this point in FastAPI (a minimal sketch follows this list).
- The (potentially placeholder-replaced) response is streamed back to the frontend.

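A minimal sketch of what that replacement step could look like, assuming a simple mapping from placeholder names to live values; the mapping contents below are illustrative, not part of the project.

```python
import re

# Hypothetical mapping from placeholder names to live values (in production this
# would come from a database, CRM, or other internal system).
context = {
    "Order Number": "12345",
    "Online Company Portal Info": "https://portal.example.com",
}

def fill_placeholders(text: str, values: dict) -> str:
    """Replace {{Placeholder Name}} tokens with real values, leaving unknown ones intact."""
    return re.sub(
        r"\{\{(.+?)\}\}",
        lambda m: values.get(m.group(1).strip(), m.group(0)),
        text,
    )

raw = "You can cancel order {{Order Number}} through {{Online Company Portal Info}}."
print(fill_placeholders(raw, context))
```
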
**Display Response:** The React frontend receives the streamed response and displays it to the user, providing a real-time conversational experience.