"""
Multi-Domain Classifier - HuggingFace Space

Interactive testing interface for the fine-tuned Phi-3 domain classifier.
"""

import gradio as gr
import json
import math
import re
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# ============================================================================
# MODEL LOADING
# ============================================================================

print("🔄 Loading model...")

# Configuration
MODEL_ID = "ovinduG/multi-domain-classifier-phi3"  # Your HuggingFace model
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"

# Load the base model, attach the LoRA adapter, and load the adapter's tokenizer.
# Any failure is reported and re-raised: the Space cannot work without a model.
try:
    print(f"Loading base model: {BASE_MODEL}")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    print(f"Loading LoRA adapter: {MODEL_ID}")
    model = PeftModel.from_pretrained(base_model, MODEL_ID)

    print(f"Loading tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    model.eval()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise

# ============================================================================
# DOMAINS
# ============================================================================

DOMAINS = [
    'coding', 'api_generation', 'mathematics', 'data_analysis',
    'science', 'medicine', 'business', 'law', 'technology',
    'literature', 'creative_content', 'education',
    'general_knowledge', 'ambiguous', 'sensitive',
]

# ============================================================================
# EXAMPLE QUERIES
# ============================================================================

EXAMPLES = [
    ["Write a Python function to reverse a linked list"],
    ["Create an OpenAPI specification for a user authentication service"],
    ["Build a machine learning model to predict customer churn and create REST API endpoints"],
    ["What are the symptoms and treatment options for diabetes?"],
    ["Write Python code to solve calculus problems and visualize the results"],
    ["Design a healthcare app API that uses AI to diagnose diseases from medical images"],
    ["Explain the theory of relativity"],
    ["Create a legal document analysis system using NLP and deploy it as a web service"],
    ["How do I create a marketing strategy for a new product launch?"],
]

# ============================================================================
# CLASSIFIER CLASS
# ============================================================================


class MultiDomainClassifier:
    """Multi-domain classifier for inference.

    Prompts a causal LM to emit a JSON classification of a query and turns the
    raw generated text into a normalized result dict.
    """

    def __init__(self, model, tokenizer, domains):
        self.model = model
        self.tokenizer = tokenizer
        self.domains = domains
        self.model.eval()

    def predict(self, text: str) -> dict:
        """Classify a query.

        Args:
            text: The user query to classify.

        Returns:
            dict with keys ``primary_domain`` (str), ``primary_confidence``
            (float in [0, 1]), ``is_multi_domain`` (bool) and
            ``secondary_domains`` (list of ``{"domain": str, "confidence": float}``).
        """
        prompt = self._create_prompt(text)

        # NOTE(review): with right-side truncation, an over-long query would also
        # cut off the trailing assistant tag added by add_generation_prompt;
        # consider truncating the query text itself instead.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.model.device)

        # generate() needs a pad id; fall back to EOS if the tokenizer defines none.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                # Greedy decoding for deterministic labels. `temperature` has no
                # effect when do_sample=False, so it is not passed.
                do_sample=False,
                pad_token_id=pad_token_id,
                # NOTE(review): KV cache disabled, presumably to work around a
                # Phi-3 remote-code / transformers cache incompatibility. This
                # slows generation noticeably; confirm before re-enabling.
                use_cache=False,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )

        return self._parse_response(response)

    def _create_prompt(self, text: str) -> str:
        """Build the chat-template prompt that asks the model to classify *text*."""
        # NOTE(review): the line layout of this prompt was reconstructed from a
        # whitespace-mangled copy; a fine-tuned model is sensitive to prompt
        # formatting, so confirm it matches the system prompt used in training.
        system_prompt = f"""You are a multi-domain classifier. Classify queries into domains and detect if they span multiple domains.

Available domains: {', '.join(self.domains)}

Output format (JSON):
{{
  "primary_domain": "domain_name",
  "primary_confidence": 0.95,
  "is_multi_domain": true/false,
  "secondary_domains": [
    {{"domain": "domain_name", "confidence": 0.85}}
  ]
}}

Rules:
- primary_domain: Main domain
- primary_confidence: Score (0.0-1.0)
- is_multi_domain: true if multiple domains, false otherwise
- secondary_domains: List (empty if single-domain)"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Classify this query: {text}"},
        ]

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _parse_response(self, response: str) -> dict:
        """Parse the model's raw output into a normalized classification dict.

        Decodes the JSON object the prompt asks for. If no valid JSON object is
        found, falls back to the domain name mentioned earliest in the text,
        and finally to ``general_knowledge`` with low confidence.
        """
        try:
            result = json.loads(self._extract_json_text(response))
            if not isinstance(result, dict):
                raise ValueError("model output is not a JSON object")
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return self._fallback_from_text(response)

        primary = result.get("primary_domain")
        if not isinstance(primary, str) or not primary.strip():
            primary = "general_knowledge"

        return {
            "primary_domain": primary,
            "primary_confidence": self._to_confidence(result.get("primary_confidence"), 0.5),
            "is_multi_domain": self._to_bool(result.get("is_multi_domain", False)),
            "secondary_domains": self._normalize_secondary(result.get("secondary_domains", [])),
        }

    @staticmethod
    def _extract_json_text(response: str) -> str:
        """Return the substring of *response* most likely to be the JSON object."""
        text = response.strip()
        # Unwrap a ```json ... ``` fence if the model used one.
        if '```' in text:
            parts = text.split('```')
            text = (parts[1] if len(parts) > 1 else parts[0]).strip()
            if text.startswith('json'):
                text = text[4:].strip()
        # Drop any prose around the object: keep the outermost braces only.
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            text = text[start:end + 1]
        return text

    @staticmethod
    def _to_confidence(value, default: float) -> float:
        """Coerce *value* to a float in [0, 1]; return *default* if not numeric."""
        try:
            conf = float(value)
        except (TypeError, ValueError):
            return default
        if not math.isfinite(conf):
            return default
        return min(1.0, max(0.0, conf))

    @staticmethod
    def _to_bool(value) -> bool:
        """Interpret JSON booleans and string spellings like "false" correctly."""
        if isinstance(value, str):
            # bool("false") would be True, so parse the text explicitly.
            return value.strip().lower() in {"true", "yes", "1"}
        return bool(value)

    @classmethod
    def _normalize_secondary(cls, items) -> list:
        """Keep only well-formed ``{"domain", "confidence"}`` entries."""
        if not isinstance(items, list):
            return []
        cleaned = []
        for item in items:
            if isinstance(item, dict) and isinstance(item.get("domain"), str):
                cleaned.append({
                    "domain": item["domain"],
                    "confidence": cls._to_confidence(item.get("confidence"), 0.5),
                })
        return cleaned

    def _fallback_from_text(self, response: str) -> dict:
        """Guess the primary domain from free text when no JSON could be parsed."""
        best = None  # (position, domain) of the earliest whole-word domain mention
        for domain in self.domains:
            # Whole-word match so e.g. "law" does not match inside "lawsuit".
            match = re.search(rf"\b{re.escape(domain)}\b", response, re.IGNORECASE)
            if match and (best is None or match.start() < best[0]):
                best = (match.start(), domain)

        if best is not None:
            return {
                "primary_domain": best[1],
                "primary_confidence": 0.5,
                "is_multi_domain": False,
                "secondary_domains": [],
            }

        return {
            "primary_domain": "general_knowledge",
            "primary_confidence": 0.3,
            "is_multi_domain": False,
            "secondary_domains": [],
        }


# Initialize classifier
classifier = MultiDomainClassifier(model, tokenizer, DOMAINS)

# ============================================================================
# UI FUNCTIONS
# ============================================================================


def create_confidence_chart(primary_domain, primary_conf, secondary_domains):
    """Build a horizontal bar chart of the primary and up to 3 secondary confidences.

    Args:
        primary_domain: Name of the primary domain.
        primary_conf: Primary confidence in [0, 1].
        secondary_domains: List of ``{"domain": str, "confidence": float}`` dicts;
            only the first three are plotted.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    domains = [primary_domain]
    confidences = [primary_conf]
    colors = ['#2ecc71']  # green: primary domain

    for sec in secondary_domains[:3]:
        domains.append(sec['domain'])
        confidences.append(sec['confidence'])
        colors.append('#3498db')  # blue: secondary domains

    fig = go.Figure(data=[
        go.Bar(
            y=domains,
            x=confidences,
            orientation='h',
            marker=dict(color=colors),
            text=[f"{c:.1%}" for c in confidences],
            textposition='outside',
            # Don't clip "outside" labels of near-100% bars at the x=1 edge.
            cliponaxis=False,
        )
    ])

    fig.update_layout(
        title="Confidence Scores",
        xaxis_title="Confidence",
        yaxis_title="Domain",
        xaxis=dict(range=[0, 1], tickformat='.0%'),
        # Plotly draws the first category at the bottom; reverse so the
        # primary domain is listed on top.
        yaxis=dict(autorange='reversed'),
        height=max(250, len(domains) * 60),
        margin=dict(l=150, r=50, t=50, b=50),
    )

    return fig


def classify_query(query_text):
    """Gradio handler: classify *query_text* and render the four result panes.

    Returns:
        Tuple ``(primary_markdown, secondary_markdown, json_markdown, figure)``;
        on empty input or error the figure is ``None``.
    """
    if not query_text or not query_text.strip():
        return (
            "⚠️ Please enter a query",
            "",
            "",
            None,
        )

    try:
        result = classifier.predict(query_text.strip())

        primary_domain = result['primary_domain']
        primary_conf = result['primary_confidence']
        is_multi = result['is_multi_domain']
        secondary_domains = result.get('secondary_domains', [])

        # Primary output
        conf_emoji = "🟢" if primary_conf > 0.85 else "🟡" if primary_conf > 0.60 else "🔴"

        primary_output = f"""### 🎯 Primary Domain

**Domain:** `{primary_domain.upper()}`

**Confidence:** {conf_emoji} **{primary_conf:.1%}**

**Multi-Domain:** {'✅ Yes' if is_multi else '❌ No'}
"""

        # Secondary output
        if secondary_domains:
            secondary_output = "### 📌 Secondary Domains\n\n"
            for i, sec in enumerate(secondary_domains, 1):
                sec_conf = sec['confidence']
                sec_emoji = "🟢" if sec_conf > 0.70 else "🟡" if sec_conf > 0.50 else "🔴"
                secondary_output += f"{i}. **{sec['domain']}** {sec_emoji} {sec_conf:.1%}\n"
        else:
            secondary_output = "### 📌 Secondary Domains\n\n*None (single-domain query)*"

        # JSON output
        json_output = f"""```json
{json.dumps(result, indent=2)}
```"""

        # Chart
        chart = create_confidence_chart(primary_domain, primary_conf, secondary_domains)

        return primary_output, secondary_output, json_output, chart

    except Exception as e:
        # UI boundary: surface the error in the page instead of crashing the app.
        return f"❌ Error: {str(e)}", "", "", None


# ============================================================================
# GRADIO INTERFACE
# ============================================================================

with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Domain Classifier") as demo:

    gr.Markdown("""
    # 🎯 Multi-Domain Classifier

    Fine-tuned **Phi-3** model for classifying queries into 15+ domains with multi-domain detection.

    **Model:** [ovinduG/multi-domain-classifier-phi3](https://huggingface.co/ovinduG/multi-domain-classifier-phi3)
    """)

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="📝 Enter Your Query",
                placeholder="Type your question here...",
                lines=3,
            )

            with gr.Row():
                classify_btn = gr.Button("🚀 Classify", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", size="lg")

            gr.Markdown("### 💡 Example Queries")
            gr.Examples(
                examples=EXAMPLES,
                inputs=query_input,
                label="Click to try",
            )

        with gr.Column(scale=1):
            gr.Markdown("""
            ### 📊 About

            **Domains (15):**
            - 🖥️ Coding
            - 🔌 API Generation
            - 📐 Mathematics
            - 📊 Data Analysis
            - 🔬 Science
            - 🏥 Medicine
            - 💼 Business
            - ⚖️ Law
            - 💻 Technology
            - 📚 Literature
            - 🎨 Creative Content
            - 🎓 Education
            - 🌍 General Knowledge
            - ❓ Ambiguous
            - 🔒 Sensitive

            **Features:**
            - Primary domain detection
            - Multi-domain flagging
            - Secondary domain ranking
            - Confidence scores
            """)

    gr.Markdown("---")
    gr.Markdown("## 📈 Results")

    with gr.Row():
        primary_output = gr.Markdown(label="Primary Domain")
        secondary_output = gr.Markdown(label="Secondary Domains")

    confidence_plot = gr.Plot(label="Confidence Visualization")

    with gr.Accordion("🔍 Raw JSON Output", open=False):
        json_output = gr.Markdown()

    gr.Markdown("---")
    gr.Markdown("""
    ### 🔗 Links

    - [Model Repository](https://huggingface.co/ovinduG/multi-domain-classifier-phi3)
    - [Base Model: Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
    - [Feedback & Issues](https://huggingface.co/ovinduG/multi-domain-classifier-phi3/discussions)
    """)

    # Events — every classification trigger fills the same four result panes.
    result_outputs = [primary_output, secondary_output, json_output, confidence_plot]

    classify_btn.click(
        fn=classify_query,
        inputs=query_input,
        outputs=result_outputs,
    )

    # "Clear" resets the query box as well as the result panes.
    clear_btn.click(
        fn=lambda: ("", "", "", "", None),
        outputs=[query_input] + result_outputs,
    )

    query_input.submit(
        fn=classify_query,
        inputs=query_input,
        outputs=result_outputs,
    )

# ============================================================================
# LAUNCH
# ============================================================================

if __name__ == "__main__":
    demo.launch()