---
base_model: openai-community/gpt2-medium
pipeline_tag: question-answering
license: mit
datasets:
- SohamGhadge/casual-conversation
language:
- en
metrics:
- accuracy
library_name: peft
---

## 🧠 Fine-Tuned GPT-2 Medium for Conversational AI

This project fine-tunes the `gpt2-medium` language model to support natural, casual **conversational dialogue** using **PEFT + LoRA**.

---

### 🚀 Model Summary

* **Base model**: `gpt2-medium`
* **Objective**: Enable natural question-answering and dialogue
* **Training method**: Supervised Fine-Tuning (SFT) using PEFT with LoRA adapters
* **Tokenizer**: `gpt2` (same as the base model)
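
The card does not publish the adapter hyperparameters, but a typical LoRA setup for GPT-2 with PEFT looks like the sketch below. All values (`r`, `lora_alpha`, `lora_dropout`, the `c_attn` target) are illustrative assumptions, not the settings actually used in training:

```python
from peft import LoraConfig

# Hypothetical values -- the actual training configuration is not published.
lora_config = LoraConfig(
    r=8,                        # rank of the low-rank update matrices
    lora_alpha=16,              # scaling factor applied to the update
    target_modules=["c_attn"],  # GPT-2's fused query/key/value projection
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
```

Passing this config to `peft.get_peft_model` wraps the base model so that only the small adapter matrices are trained.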

---

### 📈 Training Metrics

| Metric              | Value          |
| ------------------- | -------------- |
| Global Steps        | 2611           |
| Final Training Loss | 2.185          |
| Training Runtime    | 430.61 seconds |
| Samples/sec         | 138.41         |
| Steps/sec           | 17.32          |
| Total FLOPs         | 1.12 × 10¹⁵    |
| Epochs              | 7.0            |

> These metrics reflect final performance after complete training.

---

### 💬 Inference Script

Chat with the model using the `talk()` function below:

```python
import torch

# Assumes `peft_model`, `tokenizer`, and `device` are already defined
# (base model loaded, LoRA adapters attached, and moved to `device`).
def talk(model=peft_model, tokenizer=tokenizer, device=device):
    print("Start chatting with the bot! Type 'exit' to stop.\n")
    while True:
        question = input("You: ")
        if question.lower() == "exit":
            print("Goodbye!")
            break

        prompt = f"User: {question}\nBot:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )

        # Drop any trailing incomplete sentence.
        if "." in response:
            response = response.rsplit(".", 1)[0] + "."
        print("Bot:", response.strip())
```

* 🤖 **Stateless**: No memory across turns (yet).
* 🌱 **Future idea**: Add memory/context for multi-turn dialogue.
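
One lightweight way that memory could work, sketched here as a possible direction rather than shipped code, is to keep a list of past `(user, bot)` turns and rebuild the prompt each time; `build_prompt` is a hypothetical helper, not part of this repository:

```python
def build_prompt(history, question, max_turns=5):
    """Rebuild the chat prompt from the most recent (user, bot) turns."""
    lines = [f"User: {u}\nBot: {b}" for u, b in history[-max_turns:]]
    lines.append(f"User: {question}\nBot:")
    return "\n".join(lines)

# After each generation, the new turn would be appended,
# e.g. history.append((question, response)).
history = [("hi", "Hello! How are you?")]
print(build_prompt(history, "what's your name?"))
```

Feeding the rebuilt prompt into `model.generate` as in `talk()` gives the bot a sliding window of context, at the cost of a longer input each turn.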

---

### ⚙️ Quick Setup

To use this model locally:

```bash
pip install transformers peft accelerate
```
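
The `talk()` function above expects `peft_model`, `tokenizer`, and `device` to exist. One way to create them is sketched below; the adapter repository id is a placeholder, since this card does not state where the LoRA weights are published:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-medium")
base_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-medium")

# "your-username/your-adapter-repo" is a placeholder -- point it at the
# actual LoRA adapter weights for this model.
peft_model = PeftModel.from_pretrained(base_model, "your-username/your-adapter-repo")
peft_model = peft_model.to(device).eval()
```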

---