---
license: gemma
base_model: google/functiongemma-270m-it
tags:
- text-classification
- domain-classification
- function-calling
- peft
- lora
- gemma
- functiongemma
datasets:
- custom
language:
- en
metrics:
- accuracy
- f1
library_name: transformers
pipeline_tag: text-classification
---
# FunctionGemma Domain Classifier
Fine-tuned **FunctionGemma-270M** for multi-domain query classification using LoRA.
## Model Details
- **Base Model:** [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it)
- **Model Size:** 270M parameters (540MB)
- **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
- **Trainable Parameters:** ~7.6M (2.75%; see the verification sketch below)
- **Training Time:** 23.3 minutes
- **Hardware:** GPU (memory optimized for <5GB VRAM)
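The trainable-parameter figure can be checked after loading the adapter. A minimal sketch using PEFT's built-in helper:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model, then attach the LoRA adapter
base = AutoModelForCausalLM.from_pretrained("google/functiongemma-270m-it")
model = PeftModel.from_pretrained(base, "ovinduG/functiongemma-domain-classifier")

# Prints trainable params, total params, and the trainable percentage
model.print_trainable_parameters()
```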
## Performance
```
Accuracy: 95.51%
F1 Score (Weighted): 0.96
F1 Score (Macro): 0.88
Training Loss: 0.3
```
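These figures can be reproduced from held-out predictions with scikit-learn. In the sketch below, `y_true` and `y_pred` are placeholders standing in for the labels and predictions over the 690-example test split:

```python
from sklearn.metrics import accuracy_score, f1_score

# Placeholders: in practice these come from running classify()
# (see Quick Start) over the test split.
y_true = ["coding", "law", "science"]
y_pred = ["coding", "law", "medicine"]

print(f"Accuracy:      {accuracy_score(y_true, y_pred):.4f}")
print(f"F1 (weighted): {f1_score(y_true, y_pred, average='weighted'):.2f}")
print(f"F1 (macro):    {f1_score(y_true, y_pred, average='macro'):.2f}")
```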
## Supported Domains (17)
1. ambiguous
2. api_generation
3. business
4. coding
5. creative_content
6. data_analysis
7. education
8. general_knowledge
9. geography
10. history
11. law
12. literature
13. mathematics
14. medicine
15. science
16. sensitive
17. technology
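Downstream code should treat this label set as closed. A small validation sketch (the `DOMAINS` constant simply mirrors the list above; it is not part of the model's API):

```python
DOMAINS = {
    "ambiguous", "api_generation", "business", "coding", "creative_content",
    "data_analysis", "education", "general_knowledge", "geography", "history",
    "law", "literature", "mathematics", "medicine", "science", "sensitive",
    "technology",
}

def validate(result: dict) -> dict:
    """Fall back to 'ambiguous' if the model emits an unknown label."""
    if result.get("primary_domain") not in DOMAINS:
        result["primary_domain"] = "ambiguous"
    return result
```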
## Use Cases
- **Query Routing:** Route user queries to specialized models/services (see the routing sketch after this list)
- **Content Classification:** Categorize text by domain
- **Multi-domain Detection:** Identify queries spanning multiple domains
- **Intent Analysis:** Understand query context and domain
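To illustrate the routing use case, here is a hypothetical dispatcher keyed on the predicted domain; the endpoint names are placeholders, and `classify()` is the helper defined in Quick Start below:

```python
# Hypothetical routing table; endpoint names are illustrative only.
HANDLERS = {
    "coding": "code-assistant-endpoint",
    "medicine": "medical-qa-endpoint",
    "sensitive": "human-review-queue",
}
DEFAULT_HANDLER = "general-llm-endpoint"

def route(query: str) -> str:
    result = classify(query)  # classify() is defined in Quick Start
    return HANDLERS.get(result.get("primary_domain"), DEFAULT_HANDLER)
```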
## Quick Start
### Installation
```bash
pip install transformers peft torch accelerate
```
### Inference
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import json

# Load the base model and attach the LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "ovinduG/functiongemma-domain-classifier")
tokenizer = AutoTokenizer.from_pretrained("ovinduG/functiongemma-domain-classifier")

def classify(text):
    # Function schema the model is prompted to call
    function_def = {
        "type": "function",
        "function": {
            "name": "classify_query_domain",
            "description": "Classify query into domains",
            "parameters": {
                "type": "object",
                "properties": {
                    "primary_domain": {"type": "string"},
                    "primary_confidence": {"type": "number"},
                    "is_multi_domain": {"type": "boolean"},
                    "secondary_domains": {"type": "array"}
                }
            }
        }
    }

    messages = [
        {"role": "developer", "content": "You are a model that can do function calling"},
        {"role": "user", "content": text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tools=[function_def],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    # Parse the JSON payload of the function call
    if "{" in response:
        start = response.find("{")
        end = response.rfind("}") + 1
        return json.loads(response[start:end])
    return {"error": "Failed to parse response"}

# Example
result = classify("Write a Python function to calculate fibonacci numbers")
print(json.dumps(result, indent=2))
```
### Example Output
```json
{
"primary_domain": "coding",
"primary_confidence": 0.95,
"is_multi_domain": false,
"secondary_domains": []
}
```
### Multi-Domain Example
```python
result = classify("Build an ML model to predict customer churn and create REST API endpoints")
print(json.dumps(result, indent=2))
```
```json
{
"primary_domain": "data_analysis",
"primary_confidence": 0.85,
"is_multi_domain": true,
"secondary_domains": [
{
"domain": "api_generation",
"confidence": 0.75
}
]
}
```
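When acting on multi-domain results like the one above, it is often useful to keep only secondary domains above a confidence cutoff. A small sketch (the 0.5 threshold is an assumption, not a tuned value):

```python
def active_domains(result: dict, threshold: float = 0.5) -> list:
    """Return the primary domain plus secondary domains above the threshold."""
    domains = [result["primary_domain"]]
    if result.get("is_multi_domain"):
        domains += [
            d["domain"]
            for d in result.get("secondary_domains", [])
            if d.get("confidence", 0.0) >= threshold
        ]
    return domains

# For the example above: ['data_analysis', 'api_generation']
```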
## Training Details
### Dataset
- **Total Samples:** 5,046
- **Training Samples:** 3,666
- **Validation Samples:** 690
- **Test Samples:** 690
- **Multi-domain Queries:** 546 (10.8%)
### Training Configuration
```python
# LoRA Configuration
r = 32
lora_alpha = 64
lora_dropout = 0.05
target_modules = ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
# Training Configuration
num_epochs = 5
batch_size = 4
gradient_accumulation_steps = 8
learning_rate = 0.0003
max_length = 1024
optimizer = "adamw_8bit" # Memory optimized
```
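A sketch of how these hyperparameters map onto `peft.LoraConfig` and `transformers.TrainingArguments`; the exact training script is not published, so treat this as an approximation:

```python
from peft import LoraConfig
from transformers import TrainingArguments

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

training_args = TrainingArguments(
    output_dir="functiongemma-domain-classifier",  # hypothetical path
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,   # effective batch size of 32
    learning_rate=3e-4,
    optim="adamw_8bit",              # bitsandbytes 8-bit AdamW
    gradient_checkpointing=True,
    bf16=True,
)
```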
### Memory Optimization
This model was trained with memory optimizations to run on GPUs with <5GB VRAM:
- **8-bit Optimizer:** Reduces optimizer memory by 50%
- **Gradient Checkpointing:** Trades compute for memory
- **Smaller Batches:** 4 samples per batch with gradient accumulation
- **Shorter Sequences:** 1024 tokens max (vs 2048)
**Total VRAM Usage:** ~4GB (vs ~40GB without optimization)
## Performance by Domain
| Domain | Precision | Recall | F1-Score | Support |
|--------|-----------|--------|----------|---------|
| ambiguous | 0.98 | 1.00 | 0.99 | 45 |
| api_generation | 0.98 | 1.00 | 0.99 | 45 |
| business | 0.98 | 0.93 | 0.95 | 44 |
| coding | 0.98 | 0.96 | 0.97 | 48 |
| creative_content | 0.90 | 1.00 | 0.95 | 45 |
| data_analysis | 0.96 | 0.98 | 0.97 | 46 |
| education | 0.98 | 0.96 | 0.97 | 45 |
| general_knowledge | 0.76 | 0.84 | 0.80 | 45 |
| law | 0.98 | 0.94 | 0.96 | 49 |
| literature | 1.00 | 0.93 | 0.97 | 45 |
| mathematics | 1.00 | 1.00 | 1.00 | 47 |
| medicine | 0.98 | 0.89 | 0.93 | 46 |
| science | 1.00 | 0.98 | 0.99 | 47 |
| sensitive | 0.92 | 1.00 | 0.96 | 45 |
| technology | 1.00 | 0.93 | 0.97 | 46 |
**Overall Accuracy:** 95.51%
## Advantages
- **Tiny Size:** 270M parameters (14x smaller than Phi-3)
- **Fast Inference:** 0.3s on CPU, 0.08s on GPU
- **Low Memory:** Runs on 4GB VRAM
- **High Accuracy:** 95.51% (competitive with larger models)
- **Multi-domain:** Detects queries spanning multiple domains
- **Function Calling:** Built-in structured output
- **Mobile-Ready:** Can deploy on smartphones
## Limitations
- Trained on English queries only
- Performance varies by domain (see table above)
- May struggle with highly ambiguous queries
- Limited to 17 pre-defined domains
## Base Model
- **Base model:** `google/functiongemma-270m-it`
- **Model family:** Gemma
- **Model owner:** Google LLC
- **Fine-tuning task:** Domain classification
## Acknowledgement & Attribution
This model is built upon Google’s FunctionGemma. Use of this model is subject to the **Gemma Terms of Use** and the **Gemma Prohibited Use Policy**:
- **Gemma Terms of Use (March 24, 2025):**
[https://ai.google.dev/gemma/terms](https://ai.google.dev/gemma/terms)
- **Gemma Prohibited Use Policy (February 21, 2024):**
[https://ai.google.dev/gemma/prohibited_use_policy](https://ai.google.dev/gemma/prohibited_use_policy)
Users must comply with these policies when using, modifying, or distributing this model or its derivatives.
## License
This model follows the same terms as **Google’s Gemma models**. Please review the above links for full license and usage restrictions.