Spaces:
Sleeping
Sleeping
| """ | |
| Multi-Domain Classifier - HuggingFace Space | |
| Interactive testing interface for the fine-tuned Phi-3 domain classifier | |
| """ | |
| import gradio as gr | |
| import json | |
| import pandas as pd | |
| from datetime import datetime | |
| import plotly.graph_objects as go | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# ============================================================================
# MODEL LOADING
# ============================================================================
print("π Loading model...")

# Hub identifiers: the adapter repository and the base checkpoint it is applied to.
MODEL_ID = "ovinduG/multi-domain-classifier-phi3"  # Your HuggingFace model
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"

# Keyword arguments forwarded unchanged to the base-model loader.
_BASE_MODEL_KWARGS = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

# Load the base model, attach the adapter, then load the tokenizer. Any failure
# is reported on stdout and re-raised so the app does not start without a model.
try:
    print(f"Loading base model: {BASE_MODEL}")
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **_BASE_MODEL_KWARGS)

    print(f"Loading LoRA adapter: {MODEL_ID}")
    model = PeftModel.from_pretrained(base_model, MODEL_ID)

    print(f"Loading tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    model.eval()
    print("β Model loaded successfully!")
except Exception as load_error:
    print(f"β Error loading model: {load_error}")
    raise
# ============================================================================
# DOMAINS
# ============================================================================
# Labels offered to the model in the classification prompt; the order is also
# the order in which the classifier scans for domain names.
DOMAINS = [
    'coding',
    'api_generation',
    'mathematics',
    'data_analysis',
    'science',
    'medicine',
    'business',
    'law',
    'technology',
    'literature',
    'creative_content',
    'education',
    'general_knowledge',
    'ambiguous',
    'sensitive',
]
# ============================================================================
# EXAMPLE QUERIES
# ============================================================================
_EXAMPLE_QUERIES = (
    "Write a Python function to reverse a linked list",
    "Create an OpenAPI specification for a user authentication service",
    "Build a machine learning model to predict customer churn and create REST API endpoints",
    "What are the symptoms and treatment options for diabetes?",
    "Write Python code to solve calculus problems and visualize the results",
    "Design a healthcare app API that uses AI to diagnose diseases from medical images",
    "Explain the theory of relativity",
    "Create a legal document analysis system using NLP and deploy it as a web service",
    "How do I create a marketing strategy for a new product launch?",
)

# One single-item row per query, the shape passed to `gr.Examples` below.
EXAMPLES = [[query] for query in _EXAMPLE_QUERIES]
# ============================================================================
# CLASSIFIER CLASS
# ============================================================================
class MultiDomainClassifier:
    """Prompt-based domain classifier built on a causal LM and its tokenizer.

    `predict` renders a chat prompt listing the available domains, generates a
    completion, and converts it into a dict with the keys ``primary_domain``,
    ``primary_confidence``, ``is_multi_domain`` and ``secondary_domains``.
    """

    # Domain reported when nothing usable can be recovered from the model output.
    _DEFAULT_DOMAIN = "general_knowledge"

    def __init__(self, model, tokenizer, domains):
        """Store the model, tokenizer and domain labels; put the model in eval mode."""
        self.model = model
        self.tokenizer = tokenizer
        self.domains = domains
        self.model.eval()

    def predict(self, text: str) -> dict:
        """Classify ``text`` and return the normalised result dict.

        Args:
            text: The user query to classify.

        Returns:
            A dict with ``primary_domain`` (str), ``primary_confidence``
            (float in [0, 1]), ``is_multi_domain`` (bool) and
            ``secondary_domains`` (list of ``{"domain", "confidence"}`` dicts).
        """
        prompt = self._create_prompt(text)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.model.device)

        # Fall back to EOS when the tokenizer defines no pad token, so that
        # `generate` never receives `pad_token_id=None`.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            # Greedy decoding: `temperature` is intentionally not passed because
            # it has no effect when `do_sample=False` and only triggers a warning.
            # NOTE(review): `use_cache=False` is kept from the original code;
            # presumably a workaround for a KV-cache incompatibility - confirm
            # before enabling caching.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=False,
                pad_token_id=pad_token_id,
                use_cache=False,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = inputs.input_ids.shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
        return self._parse_response(response)

    def _create_prompt(self, text: str) -> str:
        """Render the chat-template prompt (system instructions + user query)."""
        # NOTE(review): prompt wording is reproduced exactly as found; it
        # presumably has to match the format used during fine-tuning - confirm
        # before editing.
        system_prompt = f"""You are a multi-domain classifier. Classify queries into domains and detect if they span multiple domains.
Available domains:
{', '.join(self.domains)}
Output format (JSON):
{{
"primary_domain": "domain_name",
"primary_confidence": 0.95,
"is_multi_domain": true/false,
"secondary_domains": [
{{"domain": "domain_name", "confidence": 0.85}}
]
}}
Rules:
- primary_domain: Main domain
- primary_confidence: Score (0.0-1.0)
- is_multi_domain: true if multiple domains, false otherwise
- secondary_domains: List (empty if single-domain)"""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Classify this query: {text}"},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _parse_response(self, response: str) -> dict:
        """Turn the raw completion into the normalised result dict.

        A JSON object is extracted when possible (tolerating code fences and
        surrounding prose). Otherwise the first domain name mentioned in the
        text is used with confidence 0.5, and if none is mentioned the default
        domain is returned with confidence 0.3.
        """
        payload = self._extract_json_object(response)
        if payload is not None:
            return self._normalise_payload(payload)
        return self._fallback_result(response)

    @staticmethod
    def _extract_json_object(response: str):
        """Return the first JSON object found in ``response`` as a dict, else None."""
        stripped = response.strip()

        # Prefer the content of the first fenced block, then the whole text.
        candidates = []
        if '```' in stripped:
            candidates.append(stripped.split('```')[1])
        candidates.append(stripped)

        decoder = json.JSONDecoder()
        for candidate in candidates:
            start = candidate.find('{')
            if start == -1:
                continue
            try:
                # raw_decode ignores anything after the closing brace.
                parsed, _ = decoder.raw_decode(candidate, start)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def _normalise_payload(self, payload: dict) -> dict:
        """Validate and coerce a decoded JSON object into the result schema."""
        primary = payload.get("primary_domain", self._DEFAULT_DOMAIN)
        if not isinstance(primary, str) or not primary.strip():
            primary = self._DEFAULT_DOMAIN

        # Models sometimes emit booleans as strings; bool("false") would be True.
        multi_flag = payload.get("is_multi_domain", False)
        if isinstance(multi_flag, str):
            multi_flag = multi_flag.strip().lower() in ("true", "yes", "1")

        # Keep only well-formed secondary entries so consumers can index
        # ['domain'] and ['confidence'] safely.
        secondary = []
        raw_secondary = payload.get("secondary_domains", [])
        if isinstance(raw_secondary, list):
            for entry in raw_secondary:
                if not isinstance(entry, dict):
                    continue
                name = entry.get("domain")
                if not isinstance(name, str) or not name.strip():
                    continue
                secondary.append({
                    "domain": name,
                    "confidence": self._coerce_confidence(entry.get("confidence"), 0.0),
                })

        return self._build_result(
            primary,
            self._coerce_confidence(payload.get("primary_confidence", 0.5), 0.5),
            bool(multi_flag),
            secondary,
        )

    def _fallback_result(self, response: str) -> dict:
        """Build a result from the earliest whole-word domain mention, if any."""
        lowered = response.lower()
        first_domain = None
        first_pos = len(lowered)
        for domain in self.domains:
            pos = self._find_whole_token(lowered, domain.lower())
            if 0 <= pos < first_pos:
                first_domain, first_pos = domain, pos

        if first_domain is None:
            return self._build_result(self._DEFAULT_DOMAIN, 0.3)
        return self._build_result(first_domain, 0.5)

    @staticmethod
    def _find_whole_token(text: str, token: str) -> int:
        """Return the index of ``token`` in ``text`` as a standalone word, or -1.

        A match is rejected when it is directly preceded or followed by a
        letter, digit or underscore (so ``law`` does not match ``flaw``).
        """
        if not token:
            return -1

        def is_word_char(ch):
            return ch.isalnum() or ch == '_'

        start = text.find(token)
        while start != -1:
            end = start + len(token)
            before_ok = start == 0 or not is_word_char(text[start - 1])
            after_ok = end == len(text) or not is_word_char(text[end])
            if before_ok and after_ok:
                return start
            start = text.find(token, start + 1)
        return -1

    @staticmethod
    def _coerce_confidence(value, default: float) -> float:
        """Convert ``value`` to a float clamped to [0, 1]; use ``default`` if invalid."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return default
        if score != score:  # NaN never equals itself
            return default
        return min(1.0, max(0.0, score))

    @staticmethod
    def _build_result(primary_domain, primary_confidence,
                      is_multi_domain=False, secondary_domains=None) -> dict:
        """Assemble the result dict returned by `predict`."""
        return {
            "primary_domain": primary_domain,
            "primary_confidence": primary_confidence,
            "is_multi_domain": is_multi_domain,
            "secondary_domains": [] if secondary_domains is None else secondary_domains,
        }
# Module-level classifier instance, built from the model, tokenizer and domain
# list defined above.
classifier = MultiDomainClassifier(
    model=model,
    tokenizer=tokenizer,
    domains=DOMAINS,
)
# ============================================================================
# UI FUNCTIONS
# ============================================================================
def create_confidence_chart(primary_domain, primary_conf, secondary_domains):
    """Build a horizontal bar chart of the primary and secondary confidences.

    Args:
        primary_domain: Label of the primary domain (drawn in green).
        primary_conf: Confidence of the primary domain, expected in [0, 1].
        secondary_domains: Iterable of ``{"domain", "confidence"}`` dicts, or
            None. At most the first three well-formed entries are drawn (in
            blue); entries that are not dicts, lack a key, or carry a
            non-numeric confidence are skipped instead of raising.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    domains = [primary_domain]
    confidences = [primary_conf]
    colors = ['#2ecc71']

    # Model output is not guaranteed to be a list of well-formed dicts.
    if not isinstance(secondary_domains, (list, tuple)):
        secondary_domains = []

    shown = 0
    for sec in secondary_domains:
        if shown == 3:
            break
        if not isinstance(sec, dict):
            continue
        try:
            sec_name = sec['domain']
            sec_conf = float(sec['confidence'])
        except (KeyError, TypeError, ValueError):
            continue
        domains.append(sec_name)
        confidences.append(sec_conf)
        colors.append('#3498db')
        shown += 1

    fig = go.Figure(data=[
        go.Bar(
            y=domains,
            x=confidences,
            orientation='h',
            marker=dict(color=colors),
            text=[f"{c:.1%}" for c in confidences],
            textposition='outside',
        )
    ])
    fig.update_layout(
        title="Confidence Scores",
        xaxis_title="Confidence",
        yaxis_title="Domain",
        xaxis=dict(range=[0, 1], tickformat='.0%'),
        # Horizontal bars are drawn bottom-up by default; reverse the axis so
        # the primary domain appears at the top.
        yaxis=dict(autorange='reversed'),
        height=max(250, len(domains) * 60),
        margin=dict(l=150, r=50, t=50, b=50),
    )
    return fig
def classify_query(query_text):
    """Classify a query and format the result for the four UI outputs.

    Returns a 4-tuple ``(primary_markdown, secondary_markdown, json_markdown,
    chart)``. A blank query yields a warning in the first slot; any exception
    raised while classifying or formatting is reported in the first slot
    instead of propagating.
    """
    # NOTE(review): several status glyphs in the strings below look
    # mis-encoded; they are kept exactly as found - confirm the intended emoji.
    if not (query_text and query_text.strip()):
        return ("β οΈ Please enter a query", "", "", None)

    def badge(score, high, medium):
        # Three-level indicator: above `high`, above `medium`, otherwise low.
        if score > high:
            return "π’"
        if score > medium:
            return "π‘"
        return "π΄"

    try:
        prediction = classifier.predict(query_text.strip())

        primary_domain = prediction['primary_domain']
        primary_conf = prediction['primary_confidence']
        is_multi = prediction['is_multi_domain']
        secondary_domains = prediction.get('secondary_domains', [])

        # Primary section
        conf_emoji = badge(primary_conf, 0.85, 0.60)
        primary_output = f"""### π― Primary Domain
**Domain:** `{primary_domain.upper()}`
**Confidence:** {conf_emoji} **{primary_conf:.1%}**
**Multi-Domain:** {'β Yes' if is_multi else 'β No'}
"""

        # Secondary section: numbered list, or a placeholder when empty
        secondary_header = "### π Secondary Domains\n\n"
        if secondary_domains:
            rows = []
            for rank, sec in enumerate(secondary_domains, 1):
                sec_conf = sec['confidence']
                sec_emoji = badge(sec_conf, 0.70, 0.50)
                rows.append(f"{rank}. **{sec['domain']}** {sec_emoji} {sec_conf:.1%}\n")
            secondary_output = secondary_header + "".join(rows)
        else:
            secondary_output = secondary_header + "*None (single-domain query)*"

        # Raw result rendered as a fenced JSON block
        json_output = "```json\n" + json.dumps(prediction, indent=2) + "\n```"

        chart = create_confidence_chart(primary_domain, primary_conf, secondary_domains)
        return primary_output, secondary_output, json_output, chart
    except Exception as e:
        return f"β Error: {str(e)}", "", "", None
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
def _reset_outputs():
    """Return empty values for the four result components."""
    return "", "", "", None


# NOTE(review): several glyphs in the UI strings below look mis-encoded; they
# are kept exactly as found - confirm the intended emoji.
with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Domain Classifier") as demo:
    # Page header
    gr.Markdown("""
# π― Multi-Domain Classifier
Fine-tuned **Phi-3** model for classifying queries into 15+ domains with multi-domain detection.
**Model:** [ovinduG/multi-domain-classifier-phi3](https://huggingface.co/ovinduG/multi-domain-classifier-phi3)
""")

    with gr.Row():
        # Left column: query input, action buttons, clickable examples
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="π Enter Your Query",
                placeholder="Type your question here...",
                lines=3,
            )
            with gr.Row():
                classify_btn = gr.Button("π Classify", variant="primary", size="lg")
                clear_btn = gr.Button("ποΈ Clear", size="lg")
            gr.Markdown("### π‘ Example Queries")
            gr.Examples(
                examples=EXAMPLES,
                inputs=query_input,
                label="Click to try",
            )

        # Right column: static description of domains and features
        with gr.Column(scale=1):
            gr.Markdown("""
### π About
**Domains (15):**
- π₯οΈ Coding
- π API Generation
- π Mathematics
- π Data Analysis
- π¬ Science
- π₯ Medicine
- πΌ Business
- βοΈ Law
- π» Technology
- π Literature
- π¨ Creative Content
- π Education
- π General Knowledge
- β Ambiguous
- π Sensitive
**Features:**
- Primary domain detection
- Multi-domain flagging
- Secondary domain ranking
- Confidence scores
""")

    gr.Markdown("---")
    gr.Markdown("## π Results")

    # Result area
    with gr.Row():
        primary_output = gr.Markdown(label="Primary Domain")
        secondary_output = gr.Markdown(label="Secondary Domains")
    confidence_plot = gr.Plot(label="Confidence Visualization")
    with gr.Accordion("π Raw JSON Output", open=False):
        json_output = gr.Markdown()

    gr.Markdown("---")
    gr.Markdown("""
### π Links
- [Model Repository](https://huggingface.co/ovinduG/multi-domain-classifier-phi3)
- [Base Model: Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Feedback & Issues](https://huggingface.co/ovinduG/multi-domain-classifier-phi3/discussions)
""")

    # Events: every handler writes to the same four result components.
    _result_components = [primary_output, secondary_output, json_output, confidence_plot]

    classify_btn.click(
        fn=classify_query,
        inputs=query_input,
        outputs=_result_components,
    )
    clear_btn.click(
        fn=_reset_outputs,
        outputs=_result_components,
    )
    query_input.submit(
        fn=classify_query,
        inputs=query_input,
        outputs=_result_components,
    )

# ============================================================================
# LAUNCH
# ============================================================================
if __name__ == "__main__":
    demo.launch()