---
license: gemma
base_model: google/functiongemma-270m-it
tags:
- text-classification
- domain-classification
- function-calling
- peft
- lora
- gemma
- functiongemma
datasets:
- custom
language:
- en
metrics:
- accuracy
- f1
library_name: transformers
pipeline_tag: text-classification
---
# FunctionGemma Domain Classifier
Fine-tuned FunctionGemma-270M for multi-domain query classification using LoRA.
## Model Details

- **Base Model:** google/functiongemma-270m-it
- **Model Size:** 270M parameters (~540MB in bf16)
- **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
- **Trainable Parameters:** ~7.6M (2.75% of total)
- **Training Time:** 23.3 minutes
- **Hardware:** GPU (memory-optimized to fit in <5GB VRAM)
## Performance

- **Accuracy:** 95.51%
- **F1 Score (weighted):** 0.96
- **F1 Score (macro):** 0.88
- **Final Training Loss:** 0.3
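The gap between the two F1 scores is expected: weighted F1 averages per-domain F1 by support, while macro F1 averages all domains uniformly, so it is pulled down by weaker domains such as general_knowledge (see the per-domain table below). A minimal sketch of how these metrics are computed with scikit-learn, using hypothetical label lists in place of the real test set:

```python
# Minimal sketch: computing the metrics above with scikit-learn
# (pip install scikit-learn). y_true / y_pred are hypothetical
# parallel lists, e.g. produced by running classify() on a test set.
from sklearn.metrics import accuracy_score, f1_score

y_true = ["coding", "law", "science", "coding"]
y_pred = ["coding", "law", "science", "mathematics"]

print(f"Accuracy:      {accuracy_score(y_true, y_pred):.4f}")
print(f"F1 (weighted): {f1_score(y_true, y_pred, average='weighted'):.4f}")
print(f"F1 (macro):    {f1_score(y_true, y_pred, average='macro'):.4f}")
```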
## Supported Domains (17)
- ambiguous
- api_generation
- business
- coding
- creative_content
- data_analysis
- education
- general_knowledge
- geography
- history
- law
- literature
- mathematics
- medicine
- science
- sensitive
- technology
## Use Cases

- **Query Routing:** Route user queries to specialized models or services (see the routing sketch below)
- **Content Classification:** Categorize text by domain
- **Multi-domain Detection:** Identify queries spanning multiple domains
- **Intent Analysis:** Understand query context and domain
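For the query-routing use case, the classifier's JSON output maps naturally onto a dispatch table. A minimal sketch, where the route names are illustrative placeholders and `classify()` is the helper defined under Quick Start below:

```python
# Hypothetical routing table: domain -> downstream service.
# The endpoint names are placeholders, not real services.
ROUTES = {
    "coding": "code-assistant",
    "medicine": "medical-qa",
    "law": "legal-qa",
}
DEFAULT_ROUTE = "general-llm"

def route(query: str) -> str:
    """Classify a query and pick a downstream service."""
    result = classify(query)  # classify() is defined in Quick Start below
    return ROUTES.get(result.get("primary_domain"), DEFAULT_ROUTE)
```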
## Quick Start

### Installation

```bash
pip install transformers peft torch
```

### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import json

# Load the base model and attach the LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "ovinduG/functiongemma-domain-classifier")
tokenizer = AutoTokenizer.from_pretrained("ovinduG/functiongemma-domain-classifier")

def classify(text):
    """Classify a query into domains via a function call."""
    # Function schema the model is asked to fill in
    function_def = {
        "type": "function",
        "function": {
            "name": "classify_query_domain",
            "description": "Classify query into domains",
            "parameters": {
                "type": "object",
                "properties": {
                    "primary_domain": {"type": "string"},
                    "primary_confidence": {"type": "number"},
                    "is_multi_domain": {"type": "boolean"},
                    "secondary_domains": {"type": "array"}
                }
            }
        }
    }

    messages = [
        {"role": "developer", "content": "You are a model that can do function calling"},
        {"role": "user", "content": text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tools=[function_def],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    # Parse the JSON arguments out of the function call
    if "{" in response:
        start = response.find("{")
        end = response.rfind("}") + 1
        return json.loads(response[start:end])
    return {"error": "Failed to parse response"}

# Example
result = classify("Write a Python function to calculate fibonacci numbers")
print(json.dumps(result, indent=2))
```
### Example Output

```json
{
  "primary_domain": "coding",
  "primary_confidence": 0.95,
  "is_multi_domain": false,
  "secondary_domains": []
}
```
### Multi-Domain Example

```python
result = classify("Build an ML model to predict customer churn and create REST API endpoints")
print(json.dumps(result, indent=2))
```

```json
{
  "primary_domain": "data_analysis",
  "primary_confidence": 0.85,
  "is_multi_domain": true,
  "secondary_domains": [
    {
      "domain": "api_generation",
      "confidence": 0.75
    }
  ]
}
```
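When `is_multi_domain` is true, a caller can fan the query out to every domain that clears a confidence threshold. A minimal sketch; the 0.5 threshold is an illustrative assumption, not a tuned value:

```python
def domains_above_threshold(result: dict, threshold: float = 0.5) -> list[str]:
    """Collect primary and secondary domains whose confidence clears the threshold."""
    domains = []
    if result.get("primary_confidence", 0.0) >= threshold:
        domains.append(result["primary_domain"])
    for entry in result.get("secondary_domains", []):
        if entry.get("confidence", 0.0) >= threshold:
            domains.append(entry["domain"])
    return domains

# For the query above: ["data_analysis", "api_generation"]
```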
## Training Details

### Dataset

- **Total Samples:** 5,046
- **Training Samples:** 3,666
- **Validation Samples:** 690
- **Test Samples:** 690
- **Multi-domain Queries:** 546 (10.8%)
### Training Configuration

```python
# LoRA configuration
r = 32
lora_alpha = 64
lora_dropout = 0.05
target_modules = ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']

# Training configuration
num_epochs = 5
batch_size = 4
gradient_accumulation_steps = 8
learning_rate = 3e-4
max_length = 1024
optimizer = "adamw_8bit"  # memory optimized
```
### Memory Optimization

This model was trained with memory optimizations so that it fits on GPUs with <5GB of VRAM:

- **8-bit Optimizer:** Reduces optimizer-state memory by roughly 50%
- **Gradient Checkpointing:** Trades compute for memory
- **Smaller Batches:** 4 samples per batch, with gradient accumulation to an effective batch of 32
- **Shorter Sequences:** 1024 tokens max (vs. 2048)

**Total VRAM Usage:** ~4GB (vs. ~40GB without optimization)
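With the HuggingFace Trainer, the optimizations above correspond roughly to these arguments (a sketch under the same caveat as before: the exact training script is not published):

```python
from transformers import TrainingArguments

# Approximate memory-oriented Trainer settings matching the list above.
# The 8-bit optimizer requires bitsandbytes (pip install bitsandbytes).
training_args = TrainingArguments(
    output_dir="./functiongemma-domain-classifier",
    num_train_epochs=5,
    per_device_train_batch_size=4,   # small per-device batch
    gradient_accumulation_steps=8,   # effective batch of 32
    learning_rate=3e-4,
    optim="adamw_8bit",              # 8-bit optimizer states
    gradient_checkpointing=True,     # recompute activations to save memory
    bf16=True,
)
```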
## Performance by Domain
| Domain | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| ambiguous | 0.98 | 1.00 | 0.99 | 45 |
| api_generation | 0.98 | 1.00 | 0.99 | 45 |
| business | 0.98 | 0.93 | 0.95 | 44 |
| coding | 0.98 | 0.96 | 0.97 | 48 |
| creative_content | 0.90 | 1.00 | 0.95 | 45 |
| data_analysis | 0.96 | 0.98 | 0.97 | 46 |
| education | 0.98 | 0.96 | 0.97 | 45 |
| general_knowledge | 0.76 | 0.84 | 0.80 | 45 |
| law | 0.98 | 0.94 | 0.96 | 49 |
| literature | 1.00 | 0.93 | 0.97 | 45 |
| mathematics | 1.00 | 1.00 | 1.00 | 47 |
| medicine | 0.98 | 0.89 | 0.93 | 46 |
| science | 1.00 | 0.98 | 0.99 | 47 |
| sensitive | 0.92 | 1.00 | 0.96 | 45 |
| technology | 1.00 | 0.93 | 0.97 | 46 |
**Overall Accuracy: 95.51%**
## Advantages

- ✅ **Tiny Size:** 270M parameters (14x smaller than Phi-3)
- ✅ **Fast Inference:** ~0.3s per query on CPU, ~0.08s on GPU (see the timing sketch below)
- ✅ **Low Memory:** Runs on 4GB of VRAM
- ✅ **High Accuracy:** 95.51%, competitive with larger models
- ✅ **Multi-domain:** Detects queries spanning multiple domains
- ✅ **Function Calling:** Built-in structured output
- ✅ **Mobile-Ready:** Small enough to deploy on smartphones
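Latency depends on hardware, sequence length, and `max_new_tokens`; a minimal sketch for measuring it on your own machine, reusing the `classify()` helper from Quick Start:

```python
import time

# Rough wall-clock timing of a single classification call.
# Run once first to exclude warm-up overhead (weight loading, CUDA init).
classify("warm-up query")

start = time.perf_counter()
classify("Explain the difference between TCP and UDP")
print(f"Latency: {time.perf_counter() - start:.3f}s")
```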
## Limitations
- Trained on English queries only
- Performance varies by domain (see table above)
- May struggle with highly ambiguous queries
- Limited to 17 pre-defined domains
## Citation

If you use this model, please cite:

```bibtex
@misc{functiongemma-domain-classifier,
  author = {ovinduG},
  title = {FunctionGemma Domain Classifier},
  year = {2024},
  publisher = {HuggingFace},
  howpublished = {\url{https://huggingface.co/ovinduG/functiongemma-domain-classifier}}
}
```
## License

This model is based on FunctionGemma and is distributed under the same Gemma license.
## Acknowledgments

- **Base Model:** Google's FunctionGemma-270M
- **Training Framework:** HuggingFace Transformers + PEFT
- **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
*Built with ❤️ using FunctionGemma*