license: mit
---
```python
import json
import os

import torch
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

import jsonlines  # Recommended for reading .jsonl files


def remove_duplicates_jsonl(input_file, output_file):
    """
    Remove duplicate entries from a JSONL file based on prompt and response.
    Preserves the first occurrence of each unique entry.

    Args:
        input_file (str): Path to the input JSONL file.
        output_file (str): Path to the output JSONL file where deduplicated data will be written.
    """
    # Store unique entries using prompt+response as the key
    unique_entries = {}
    duplicate_count = 0
    line_count = 0

    print(f"Processing file: {input_file}")

    try:
        # Read the input file and track unique entries
        with open(input_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    # Skip empty lines
                    if not line.strip():
                        continue

                    # Parse the JSON line
                    data = json.loads(line.strip())
                    line_count += 1

                    # Create a unique key from prompt+response
                    unique_key = f"{data.get('prompt', '')}|{data.get('response', '')}"

                    # Keep only the first occurrence of each unique entry
                    if unique_key not in unique_entries:
                        unique_entries[unique_key] = data
                    else:
                        duplicate_count += 1

                except json.JSONDecodeError as e:
                    print(f"Error parsing JSON on line {line_num}: {str(e)}")
                    continue

        # Write the unique entries to the output file
        with open(output_file, 'w', encoding='utf-8') as f:
            for data in unique_entries.values():
                json_str = json.dumps(data, ensure_ascii=False)
                f.write(json_str + '\n')

        # Print a summary
        print("\nDeduplication Summary:")
        print(f"Total lines processed: {line_count}")
        print(f"Duplicate entries removed: {duplicate_count}")
        print(f"Unique entries remaining: {len(unique_entries)}")
        print(f"Output written to: {output_file}")

    except Exception as e:
        print(f"Error processing file: {str(e)}")
        return


if __name__ == "__main__":
    input_file = "prompt_response_pairs.jsonl"
    output_file = "prompt_response_pairs_deduped.jsonl"

    remove_duplicates_jsonl(input_file, output_file)
```
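Each line of the input file is a single JSON object whose `prompt` and `response` fields form the deduplication key. The entry below is an illustrative example of the expected schema, not a record from the actual dataset:

```json
{"prompt": "i have a question about cancelling order {{Order Number}}", "response": "I understand you need help cancelling order {{Order Number}}. You can do this through the {{Online Company Portal Info}}."}
```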
```python
import json
import jsonlines  # Recommended for reading .jsonl files

# --- 1. Load data from the .jsonl file ---
file_path = "prompt_response_pairs_deduped.jsonl"

# Read the .jsonl file
data = []
with jsonlines.open(file_path, 'r') as reader:
    for obj in reader:
        data.append(obj)

print(f"Loaded {len(data)} entries from {file_path}")
if len(data) > 0:  # Check that the data is not empty before printing
    print("First entry:", data[0])

# GPU check
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

from unsloth import FastLanguageModel

model_name = "unsloth/llama-3-8b-bnb-4bit"

max_seq_length = 2048  # Chosen sequence length
dtype = None           # Auto detection

# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)
```
```python
from datasets import Dataset

def format_prompt(example):
    user_prompt = example.get('prompt', '')
    assistant_response = example.get('response', '')
    # Llama 3 chat template: each training example is a full conversation turn
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{assistant_response}<|eot_id|>"
    )

formatted_data = [format_prompt(item) for item in data]  # Use 'data' loaded from the .jsonl file
dataset = Dataset.from_dict({"text": formatted_data})

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=64,  # LoRA rank - higher = more capacity, more memory
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=128,  # LoRA scaling factor (usually 2x rank)
    lora_dropout=0,  # Supports any value, but 0 is optimized
    bias="none",     # Supports any value, but "none" is optimized
    use_gradient_checkpointing="unsloth",  # Unsloth's optimized version
    random_state=3407,
    use_rslora=False,   # Rank-stabilized LoRA
    loftq_config=None,  # LoftQ
)
```
```python
from trl import SFTTrainer
from transformers import TrainingArguments

# Training arguments optimized for Unsloth
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # Effective batch size = 8
        warmup_steps=10,
        num_train_epochs=3,
        learning_rate=2e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=25,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        save_strategy="epoch",
        save_total_limit=2,
        dataloader_pin_memory=False,
    ),
)

# Train the model
trainer_stats = trainer.train()
print("Training complete!")

# Option A: Save just the LoRA adapters (most common for continued fine-tuning).
# This creates a folder with adapter_model.safetensors and the tokenizer files.
lora_model_dir = "./lora_adapters_saved"
print(f"\nSaving LoRA adapters to: {lora_model_dir}")
model.save_pretrained(lora_model_dir)
tokenizer.save_pretrained(lora_model_dir)
print("LoRA adapters and tokenizer saved!")
```
```python
# --- 5. Test the fine-tuned model ---
FastLanguageModel.for_inference(model)  # Enable Unsloth's native 2x-faster inference

# For inference, the prompt is just the user's side of the conversation;
# apply_chat_template adds the Llama 3 formatting.
messages = [
    {"role": "user", "content": "i have a question about cancelling order 12345"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,  # Adds the assistant turn for generation
    return_tensors="pt",
).to("cuda")

# Generate a response
print("\nGenerating response with the fine-tuned model...")
outputs = model.generate(
    input_ids=inputs,
    max_new_tokens=256,
    use_cache=True,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
)

# Decode and print (keep special tokens for inspection)
response = tokenizer.decode(outputs[0], skip_special_tokens=False)
print(response)

# To get just the generated assistant part, parse the response string.
try:
    header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    assistant_start = response.find(header)
    if assistant_start != -1:
        generated_text = response[assistant_start + len(header):].strip()
        # Remove any trailing <|eot_id|> or other special tokens
        generated_text = generated_text.replace("<|eot_id|>", "").strip()
        print("\nExtracted Generated Output:")
        print(generated_text)
except Exception as e:
    print(f"Could not extract generated text: {e}")
```
```python
# --- 6. Save the model in GGUF format ---
# This saves the model with the LoRA adapters merged into the base model and quantized.
# `model.save_pretrained_gguf` handles the merging and quantization internally
# before writing the GGUF file.
gguf_output_dir = "gguf_model"
os.makedirs(gguf_output_dir, exist_ok=True)  # Ensure the directory exists
print(f"\nSaving model to GGUF format in: {gguf_output_dir}")
model.save_pretrained_gguf(gguf_output_dir, tokenizer, quantization_method="q4_k_m")
print("Model saved in GGUF format!")
```
## Fine-Tuning Details

- **Base Model:** `unsloth/llama-3-8b-bnb-4bit`, an optimized Llama 3 variant for efficient fine-tuning.
- **Method:** LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning.
- **Tokenizer:** Uses the Llama 3 chat template to format prompts and responses for training.
- **Training Arguments:** `adamw_8bit` optimizer, a linear learning-rate scheduler, and checkpoints saved at the end of each epoch (keeping the last two).
- **Output Format:** The fine-tuned LoRA adapters are saved, and the model is also exported to GGUF for local deployment with Ollama.
## Architecture and Workflow

The chatbot operates as a full-stack application orchestrated by Nginx, with FastAPI handling the backend logic and a React.js frontend providing the user interface. Ollama serves the fine-tuned LLM locally within the container.

*Conceptual architecture, similar to a Hugging Face Space deployment.*
### Components

- **Frontend (React.js):** Provides the user interface for interacting with the chatbot, including a text input area and a voice input feature.
- **Nginx:** Acts as a reverse proxy, routing requests from the frontend to the FastAPI backend; it also serves the static frontend files.
- **FastAPI (Python):**
  - **LLM Endpoint (`/api/ask`):** Receives text prompts from the frontend, formats them, sends them to the Ollama-served LLM, and streams the generated responses back to the frontend (see the sketch after this list).
  - **Audio Transcription Endpoint (`/api/transcribe-audio`):** Receives audio blobs from the frontend, transcribes them with a Whisper model loaded within FastAPI, and returns the transcribed text (also sketched below).
- **Ollama:** A local large language model server. It runs the fine-tuned `krishna_choudhary/tinyllama` model, making it available for inference via an API that FastAPI interacts with.
- **Whisper Model (`dimavz/whisper-tiny`):** Integrated within the FastAPI application for speech-to-text functionality.
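The backend source itself is not included in this README, but a minimal sketch of the two FastAPI endpoints could look like the following. The Ollama endpoint URL and request/response shapes follow Ollama's standard HTTP API; the `openai/whisper-tiny` checkpoint is an illustrative stand-in for the project's `dimavz/whisper-tiny`, and all names here are assumptions rather than the project's confirmed implementation.

```python
# Illustrative sketch only - endpoint shapes and model names are assumptions.
import json
import tempfile

import httpx
from fastapi import FastAPI, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import pipeline

app = FastAPI()

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default generate endpoint
MODEL_NAME = "krishna_choudhary/tinyllama"          # fine-tuned model served by Ollama

# Whisper loaded once at startup; openai/whisper-tiny is a stand-in checkpoint
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")


class AskRequest(BaseModel):
    prompt: str


@app.post("/api/ask")
async def ask(req: AskRequest):
    """Forward the prompt to Ollama and stream the generated tokens back."""
    async def stream_tokens():
        payload = {"model": MODEL_NAME, "prompt": req.prompt, "stream": True}
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", OLLAMA_URL, json=payload) as resp:
                # Ollama streams newline-delimited JSON objects
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    chunk = json.loads(line)
                    yield chunk.get("response", "")
                    if chunk.get("done"):
                        break

    return StreamingResponse(stream_tokens(), media_type="text/plain")


@app.post("/api/transcribe-audio")
async def transcribe_audio(file: UploadFile):
    """Transcribe an uploaded audio blob with the in-process Whisper model."""
    # Persist the blob so the ASR pipeline (which needs ffmpeg) can decode it
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
        tmp.write(await file.read())
        path = tmp.name
    result = asr(path)
    return {"text": result["text"]}
```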
### Workflow

1. **User Input:** The user types a message or records a voice message in the React frontend.
2. **Voice to Text (if applicable):**
   - If a voice message is recorded, the frontend sends the audio blob to FastAPI's `/api/transcribe-audio` endpoint.
   - FastAPI uses the `dimavz/whisper-tiny` model to transcribe the audio into text.
   - The transcribed text is sent back to the frontend.
3. **Text Prompt to LLM:** Whether the input was typed or transcribed, the frontend sends the text prompt to FastAPI's `/api/ask` endpoint.
4. **LLM Inference (FastAPI & Ollama):**
   - FastAPI receives the user prompt.
   - It formats the prompt to match the Llama 3 chat template required by the `krishna_choudhary/tinyllama` model.
   - FastAPI sends the formatted prompt to the locally running Ollama server.
   - Ollama processes the prompt with the fine-tuned model and streams the generated response back to FastAPI.
5. **Response Handling and Token Replacement:**
   - FastAPI receives the streaming response from Ollama.
   - The fine-tuned model is trained to generate responses containing placeholders (e.g., `{{Order Number}}`, `{{Online Company Portal Info}}`).
   - FastAPI identifies these placeholders in the LLM's raw output.
   - **Token Replacement:** In a production environment, FastAPI would dynamically replace these placeholders with real-time data from a database or other internal systems (e.g., a customer's actual order number from a CRM, or a real company portal URL). In this project the model emits the placeholders exactly as seen in the training data, demonstrating that it can recognize and produce structured responses; the dynamic replacement logic would be implemented in FastAPI, as sketched after this list.
   - The (potentially placeholder-replaced) response is streamed back to the frontend.
6. **Display Response:** The React frontend receives the streamed response and displays it to the user, providing a real-time conversational experience.
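As a hedged illustration of that replacement step (the actual lookup logic and field names would depend on the internal systems involved), the substitution itself can be a small regex pass over the generated text:

```python
import re

# Matches placeholders such as {{Order Number}} in the model's output
PLACEHOLDER = re.compile(r"\{\{(.+?)\}\}")

def fill_placeholders(text: str, context: dict) -> str:
    """Replace {{...}} tokens with real values; unknown placeholders are left intact."""
    def lookup(match: re.Match) -> str:
        key = match.group(1).strip()
        return str(context.get(key, match.group(0)))
    return PLACEHOLDER.sub(lookup, text)

# Hypothetical values fetched from a CRM/database for the current user
context = {
    "Order Number": "12345",
    "Online Company Portal Info": "https://portal.example.com",
}
raw = "To cancel order {{Order Number}}, sign in at {{Online Company Portal Info}}."
print(fill_placeholders(raw, context))
# -> To cancel order 12345, sign in at https://portal.example.com.
```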