---
license: gemma
base_model: google/functiongemma-270m-it
tags:
- text-classification
- domain-classification
- function-calling
- peft
- lora
- gemma
- functiongemma
datasets:
- custom
language:
- en
metrics:
- accuracy
- f1
library_name: transformers
pipeline_tag: text-classification
---

# FunctionGemma Domain Classifier

Fine-tuned **FunctionGemma-270M** for multi-domain query classification using LoRA.

## Model Details

- **Base Model:** [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it)
- **Model Size:** 270M parameters (540MB)
- **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
- **Trainable Parameters:** ~7.6M (2.75%)
- **Training Time:** 23.3 minutes
- **Hardware:** GPU (memory-optimized for <5GB VRAM)

## Performance

```
Accuracy:            95.51%
F1 Score (Weighted): 0.96
F1 Score (Macro):    0.88
Training Loss:       0.3
```

## Supported Domains (17)

1. ambiguous
2. api_generation
3. business
4. coding
5. creative_content
6. data_analysis
7. education
8. general_knowledge
9. geography
10. history
11. law
12. literature
13. mathematics
14. medicine
15. science
16. sensitive
17. technology

## Use Cases

- **Query Routing:** Route user queries to specialized models/services
- **Content Classification:** Categorize text by domain
- **Multi-domain Detection:** Identify queries spanning multiple domains
- **Intent Analysis:** Understand query context and domain

## Quick Start

### Installation

```bash
pip install transformers peft torch
```

### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import json

# Load the base model and attach the LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "ovinduG/functiongemma-domain-classifier")
tokenizer = AutoTokenizer.from_pretrained("ovinduG/functiongemma-domain-classifier")

# Classify a query
def classify(text):
    # Define the function schema the model is trained to call
    function_def = {
        "type": "function",
        "function": {
            "name": "classify_query_domain",
            "description": "Classify query into domains",
            "parameters": {
                "type": "object",
                "properties": {
                    "primary_domain": {"type": "string"},
                    "primary_confidence": {"type": "number"},
                    "is_multi_domain": {"type": "boolean"},
                    "secondary_domains": {"type": "array"}
                }
            }
        }
    }

    messages = [
        {"role": "developer", "content": "You are a model that can do function calling"},
        {"role": "user", "content": text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tools=[function_def],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    # Parse the JSON arguments of the function call
    if "{" in response:
        start = response.find("{")
        end = response.rfind("}") + 1
        return json.loads(response[start:end])
    return {"error": "Failed to parse response"}

# Example
result = classify("Write a Python function to calculate fibonacci numbers")
print(json.dumps(result, indent=2))
```

### Example Output

```json
{
  "primary_domain": "coding",
  "primary_confidence": 0.95,
  "is_multi_domain": false,
  "secondary_domains": []
}
```

### Multi-Domain Example

```python
result = classify("Build an ML model to predict customer churn and create REST API endpoints")
print(json.dumps(result, indent=2))
```

```json
{
  "primary_domain": "data_analysis",
  "primary_confidence": 0.85,
  "is_multi_domain": true,
  "secondary_domains": [
    { "domain": "api_generation", "confidence": 0.75 }
  ]
}
```

## Training Details

### Dataset

- **Total Samples:** 5,046
- **Training Samples:** 3,666
- **Validation Samples:** 690
- **Test Samples:** 690
- **Multi-domain Queries:** 546 (10.8%)

### Training Configuration

```python
# LoRA configuration
r = 32
lora_alpha = 64
lora_dropout = 0.05
target_modules = ['q_proj', 'v_proj', 'k_proj', 'o_proj',
                  'gate_proj', 'up_proj', 'down_proj']

# Training configuration
num_epochs = 5
batch_size = 4
gradient_accumulation_steps = 8
learning_rate = 0.0003
max_length = 1024
optimizer = "adamw_8bit"  # Memory optimized
```

### Memory Optimization

This model was trained with memory optimizations so that it fits on GPUs with <5GB VRAM (see the sketch after this list for how these settings map onto `transformers` and `peft`):

- **8-bit Optimizer:** Reduces optimizer memory by 50%
- **Gradient Checkpointing:** Trades compute for memory
- **Smaller Batches:** 4 samples per batch with gradient accumulation
- **Shorter Sequences:** 1,024-token maximum (vs. 2,048)

**Total VRAM Usage:** ~4GB (vs. ~40GB without optimization)
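For reference, here is a minimal sketch of how the configuration above could be expressed with `peft` and `transformers`. It is illustrative rather than the exact training script: the dataset and output path are placeholders (the training data is custom and not published), and `adamw_bnb_8bit` is the `transformers` name for the 8-bit AdamW optimizer (requires `bitsandbytes`).

```python
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          Trainer, TrainingArguments)
from peft import LoraConfig, get_peft_model

# Base model in bf16
model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")

# LoRA configuration from the table above
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # ~7.6M trainable parameters (2.75%)

# Training arguments mirroring the configuration and the
# memory optimizations described above
args = TrainingArguments(
    output_dir="functiongemma-domain-classifier",  # placeholder path
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective batch size of 32
    learning_rate=3e-4,
    optim="adamw_bnb_8bit",         # 8-bit AdamW via bitsandbytes
    gradient_checkpointing=True,    # trade compute for memory
    bf16=True,
    logging_steps=10,
)

# The custom dataset is tokenized to max_length=1024 (not shown here);
# with a prepared train_dataset, training would run as:
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
```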
"primary_confidence": 0.85, "is_multi_domain": true, "secondary_domains": [ { "domain": "api_generation", "confidence": 0.75 } ] } ``` ## Training Details ### Dataset - **Total Samples:** 5,046 - **Training Samples:** 3,666 - **Validation Samples:** 690 - **Test Samples:** 690 - **Multi-domain Queries:** 546 (10.8%) ### Training Configuration ```python # LoRA Configuration r = 32 lora_alpha = 64 lora_dropout = 0.05 target_modules = ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'] # Training Configuration num_epochs = 5 batch_size = 4 gradient_accumulation_steps = 8 learning_rate = 0.0003 max_length = 1024 optimizer = "adamw_8bit" # Memory optimized ``` ### Memory Optimization This model was trained with memory optimizations to run on GPUs with <5GB VRAM: - **8-bit Optimizer:** Reduces optimizer memory by 50% - **Gradient Checkpointing:** Trades compute for memory - **Smaller Batches:** 4 samples per batch with gradient accumulation - **Shorter Sequences:** 1024 tokens max (vs 2048) **Total VRAM Usage:** ~4GB (vs ~40GB without optimization) ## Performance by Domain | Domain | Precision | Recall | F1-Score | Support | |--------|-----------|--------|----------|---------| | ambiguous | 0.98 | 1.00 | 0.99 | 45 | | api_generation | 0.98 | 1.00 | 0.99 | 45 | | business | 0.98 | 0.93 | 0.95 | 44 | | coding | 0.98 | 0.96 | 0.97 | 48 | | creative_content | 0.90 | 1.00 | 0.95 | 45 | | data_analysis | 0.96 | 0.98 | 0.97 | 46 | | education | 0.98 | 0.96 | 0.97 | 45 | | general_knowledge | 0.76 | 0.84 | 0.80 | 45 | | law | 0.98 | 0.94 | 0.96 | 49 | | literature | 1.00 | 0.93 | 0.97 | 45 | | mathematics | 1.00 | 1.00 | 1.00 | 47 | | medicine | 0.98 | 0.89 | 0.93 | 46 | | science | 1.00 | 0.98 | 0.99 | 47 | | sensitive | 0.92 | 1.00 | 0.96 | 45 | | technology | 1.00 | 0.93 | 0.97 | 46 | **Overall Accuracy:** 95.51% ## Advantages - ✅ **Tiny Size:** 270M parameters (14x smaller than Phi-3) - ✅ **Fast Inference:** 0.3s on CPU, 0.08s on GPU - ✅ **Low Memory:** Runs on 4GB VRAM - ✅ **High Accuracy:** 95.51% (competitive with larger models) - ✅ **Multi-domain:** Detects queries spanning multiple domains - ✅ **Function Calling:** Built-in structured output - ✅ **Mobile-Ready:** Can deploy on smartphones ## Limitations - Trained on English queries only - Performance varies by domain (see table above) - May struggle with highly ambiguous queries - Limited to 17 pre-defined domains ## Base Model - **Base model:** `google/functiongemma` - **Model family:** Gemma - **Model owner:** Google LLC - **Fine-tuning task:** Domain classification ## Acknowledgement & Attribution This model is built upon Google’s FunctionGemma. Use of this model is subject to the **Gemma Terms of Use** and the **Gemma Prohibited Use Policy**: - **Gemma Terms of Use (March 24, 2025):** [https://ai.google.dev/gemma/terms](https://ai.google.dev/gemma/terms) - **Gemma Prohibited Use Policy (February 21, 2024):** [https://ai.google.dev/gemma/prohibited_use_policy](https://ai.google.dev/gemma/prohibited_use_policy) Users must comply with these policies when using, modifying, or distributing this model or its derivatives. ## License This model follows the same terms as **Google’s Gemma models**. Please review the above links for full license and usage restrictions. ## Recommended Hugging Face Metadata ```yaml license: gemma base_model: google/functiongemma tags: - text-classification - domain-classification - gemma - functiongemma