---
license: gemma
base_model: google/functiongemma-270m-it
tags:
- text-classification
- domain-classification
- function-calling
- peft
- lora
- gemma
- functiongemma
datasets:
- custom
language:
- en
metrics:
- accuracy
- f1
library_name: transformers
pipeline_tag: text-classification
---

# FunctionGemma Domain Classifier

Fine-tuned **FunctionGemma-270M** for multi-domain query classification using LoRA.

## Model Details

- **Base Model:** [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it)
- **Model Size:** 270M parameters (540 MB)
- **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
- **Trainable Parameters:** ~7.6M (2.75%; verifiable with the sketch below)
- **Training Time:** 23.3 minutes
- **Hardware:** GPU (memory-optimized for <5 GB VRAM)
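
The trainable-parameter count can be reproduced from the published adapter; a minimal sketch using PEFT's built-in counter (`is_trainable=True` is passed so the adapter weights are loaded as trainable rather than frozen):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("google/functiongemma-270m-it")
model = PeftModel.from_pretrained(
    base,
    "ovinduG/functiongemma-domain-classifier",
    is_trainable=True,  # load adapter weights as trainable so they are counted
)

# Prints trainable vs. total parameter counts and the trainable percentage
model.print_trainable_parameters()
```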

## Performance

```
Accuracy: 95.51%
F1 Score (Weighted): 0.96
F1 Score (Macro): 0.88
Training Loss: 0.3
```

## Supported Domains (17)

1. ambiguous
2. api_generation
3. business
4. coding
5. creative_content
6. data_analysis
7. education
8. general_knowledge
9. geography
10. history
11. law
12. literature
13. mathematics
14. medicine
15. science
16. sensitive
17. technology

## Use Cases

- **Query Routing:** Route user queries to specialized models or services (see the routing sketch after this list)
- **Content Classification:** Categorize text by domain
- **Multi-domain Detection:** Identify queries spanning multiple domains
- **Intent Analysis:** Understand query context and domain
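
As an illustration of the routing use case, here is a minimal sketch that maps predicted domains to downstream handlers. The service names are placeholders, and `classify` is the function defined in Quick Start below:

```python
# Hypothetical routing table: domain -> downstream service name.
# These service names are examples, not part of this model.
ROUTES = {
    "coding": "code-assistant",
    "medicine": "medical-qa",
    "sensitive": "human-review-queue",
}

def route(query: str) -> str:
    result = classify(query)  # classify() is defined in Quick Start below
    domain = result.get("primary_domain", "ambiguous")
    # Fall back to a general-purpose service for unmapped domains.
    return ROUTES.get(domain, "general-assistant")
```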

## Quick Start

### Installation

```bash
pip install transformers peft torch
```

### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import json

# Load base model and LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "ovinduG/functiongemma-domain-classifier")
tokenizer = AutoTokenizer.from_pretrained("ovinduG/functiongemma-domain-classifier")

# Classify a query
def classify(text):
    # Function schema the model was trained to call
    function_def = {
        "type": "function",
        "function": {
            "name": "classify_query_domain",
            "description": "Classify query into domains",
            "parameters": {
                "type": "object",
                "properties": {
                    "primary_domain": {"type": "string"},
                    "primary_confidence": {"type": "number"},
                    "is_multi_domain": {"type": "boolean"},
                    "secondary_domains": {"type": "array"}
                }
            }
        }
    }

    messages = [
        {"role": "developer", "content": "You are a model that can do function calling"},
        {"role": "user", "content": text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tools=[function_def],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    # Parse the JSON arguments out of the generated function call
    if "{" in response:
        start = response.find("{")
        end = response.rfind("}") + 1
        return json.loads(response[start:end])

    return {"error": "Failed to parse response"}

# Example
result = classify("Write a Python function to calculate fibonacci numbers")
print(json.dumps(result, indent=2))
```

### Example Output

```json
{
  "primary_domain": "coding",
  "primary_confidence": 0.95,
  "is_multi_domain": false,
  "secondary_domains": []
}
```

### Multi-Domain Example

```python
result = classify("Build an ML model to predict customer churn and create REST API endpoints")
print(json.dumps(result, indent=2))
```

```json
{
  "primary_domain": "data_analysis",
  "primary_confidence": 0.85,
  "is_multi_domain": true,
  "secondary_domains": [
    {
      "domain": "api_generation",
      "confidence": 0.75
    }
  ]
}
```

## Training Details

### Dataset

- **Total Samples:** 5,046
- **Training Samples:** 3,666
- **Validation Samples:** 690
- **Test Samples:** 690
- **Multi-domain Queries:** 546 (10.8%)

### Training Configuration

```python
# LoRA configuration
r = 32
lora_alpha = 64
lora_dropout = 0.05
target_modules = ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']

# Training configuration
num_epochs = 5
batch_size = 4
gradient_accumulation_steps = 8
learning_rate = 3e-4
max_length = 1024
optimizer = "adamw_8bit"  # memory optimized
```
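
The LoRA hyperparameters above map directly onto PEFT's `LoraConfig`; a minimal sketch (the `task_type` is an assumption based on the causal-LM fine-tuning setup):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("google/functiongemma-270m-it")

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",  # assumption: fine-tuned as a causal LM
)

# Only the injected low-rank adapter weights train (~2.75% of total)
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()
```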

### Memory Optimization

This model was trained with memory optimizations to run on GPUs with <5 GB VRAM:

- **8-bit Optimizer:** Reduces optimizer memory by 50%
- **Gradient Checkpointing:** Trades compute for memory
- **Smaller Batches:** 4 samples per batch with gradient accumulation
- **Shorter Sequences:** 1,024 tokens max (vs. 2,048)

**Total VRAM Usage:** ~4 GB (vs. ~40 GB without optimization)
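
These optimizations correspond to standard `transformers.TrainingArguments` options; a minimal sketch under the assumption that the standard `Trainer` was used (the output path and `bf16` flag are placeholders/assumptions):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./functiongemma-domain-classifier",  # placeholder path
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,   # effective batch size of 32
    learning_rate=3e-4,
    optim="adamw_8bit",              # 8-bit AdamW via bitsandbytes ("adamw_bnb_8bit" on older versions)
    gradient_checkpointing=True,     # trade compute for memory
    bf16=True,                       # assumption, matching the bfloat16 inference dtype
)
```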

## Performance by Domain

| Domain | Precision | Recall | F1-Score | Support |
|--------|-----------|--------|----------|---------|
| ambiguous | 0.98 | 1.00 | 0.99 | 45 |
| api_generation | 0.98 | 1.00 | 0.99 | 45 |
| business | 0.98 | 0.93 | 0.95 | 44 |
| coding | 0.98 | 0.96 | 0.97 | 48 |
| creative_content | 0.90 | 1.00 | 0.95 | 45 |
| data_analysis | 0.96 | 0.98 | 0.97 | 46 |
| education | 0.98 | 0.96 | 0.97 | 45 |
| general_knowledge | 0.76 | 0.84 | 0.80 | 45 |
| law | 0.98 | 0.94 | 0.96 | 49 |
| literature | 1.00 | 0.93 | 0.97 | 45 |
| mathematics | 1.00 | 1.00 | 1.00 | 47 |
| medicine | 0.98 | 0.89 | 0.93 | 46 |
| science | 1.00 | 0.98 | 0.99 | 47 |
| sensitive | 0.92 | 1.00 | 0.96 | 45 |
| technology | 1.00 | 0.93 | 0.97 | 46 |

**Overall Accuracy:** 95.51%
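
A report in this format can be regenerated from held-out predictions with scikit-learn; a minimal sketch, assuming `y_true` and `y_pred` are the gold and predicted domain labels for the 690-query test split (the values shown are placeholders):

```python
from sklearn.metrics import accuracy_score, classification_report

y_true = ["coding", "medicine", "ambiguous"]   # placeholder data
y_pred = ["coding", "medicine", "sensitive"]   # placeholder data

# Per-domain precision/recall/F1/support, as in the table above
print(classification_report(y_true, y_pred, digits=2))
print(f"Overall Accuracy: {accuracy_score(y_true, y_pred):.2%}")
```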

## Advantages

- ✅ **Tiny Size:** 270M parameters (14x smaller than Phi-3)
- ✅ **Fast Inference:** 0.3 s on CPU, 0.08 s on GPU
- ✅ **Low Memory:** Runs on 4 GB VRAM
- ✅ **High Accuracy:** 95.51% (competitive with larger models)
- ✅ **Multi-domain:** Detects queries spanning multiple domains
- ✅ **Function Calling:** Built-in structured output
- ✅ **Mobile-Ready:** Can be deployed on smartphones

## Limitations

- Trained on English queries only
- Performance varies by domain (see table above)
- May struggle with highly ambiguous queries
- Limited to 17 pre-defined domains

## Base Model

- **Base model:** `google/functiongemma-270m-it`
- **Model family:** Gemma
- **Model owner:** Google LLC
- **Fine-tuning task:** Domain classification

## Acknowledgement & Attribution

This model is built upon Google’s FunctionGemma. Use of this model is subject to the **Gemma Terms of Use** and the **Gemma Prohibited Use Policy**:

- **Gemma Terms of Use (March 24, 2025):** [https://ai.google.dev/gemma/terms](https://ai.google.dev/gemma/terms)
- **Gemma Prohibited Use Policy (February 21, 2024):** [https://ai.google.dev/gemma/prohibited_use_policy](https://ai.google.dev/gemma/prohibited_use_policy)

Users must comply with these policies when using, modifying, or distributing this model or its derivatives.

## License

This model follows the same terms as **Google’s Gemma models**. Please review the links above for the full license and usage restrictions.