# phiclassifier / app.py — HuggingFace Space by ovinduG
# Last commit: "Upload 2 files" (ab2d339, verified)
"""
Multi-Domain Classifier - HuggingFace Space
Interactive testing interface for the fine-tuned Phi-3 domain classifier
"""
import json
import re
from datetime import datetime

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# ============================================================================
# MODEL LOADING
# ============================================================================
print("🔄 Loading model...")

# Configuration
MODEL_ID = "ovinduG/multi-domain-classifier-phi3"  # LoRA adapter repo (also provides the tokenizer)
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # base weights the adapter is applied to

# Load the base model, attach the LoRA adapter, then load the tokenizer.
# Any failure is logged and re-raised so the Space fails loudly at startup.
try:
    print(f"Loading base model: {BASE_MODEL}")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    print(f"Loading LoRA adapter: {MODEL_ID}")
    model = PeftModel.from_pretrained(base_model, MODEL_ID)

    print(f"Loading tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    model.eval()  # inference only
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise
# ============================================================================
# DOMAINS
# ============================================================================
# The fixed set of labels the classifier may assign (15 in total); they are
# listed verbatim in the system prompt given to the model.
DOMAINS = [
    "coding",
    "api_generation",
    "mathematics",
    "data_analysis",
    "science",
    "medicine",
    "business",
    "law",
    "technology",
    "literature",
    "creative_content",
    "education",
    "general_knowledge",
    "ambiguous",
    "sensitive",
]
# ============================================================================
# EXAMPLE QUERIES
# ============================================================================
# gr.Examples expects one row per example, one column per input component;
# there is a single textbox input, so every row is a one-element list.
EXAMPLES = [
    [query]
    for query in (
        "Write a Python function to reverse a linked list",
        "Create an OpenAPI specification for a user authentication service",
        "Build a machine learning model to predict customer churn and create REST API endpoints",
        "What are the symptoms and treatment options for diabetes?",
        "Write Python code to solve calculus problems and visualize the results",
        "Design a healthcare app API that uses AI to diagnose diseases from medical images",
        "Explain the theory of relativity",
        "Create a legal document analysis system using NLP and deploy it as a web service",
        "How do I create a marketing strategy for a new product launch?",
    )
]
# ============================================================================
# CLASSIFIER CLASS
# ============================================================================
class MultiDomainClassifier:
    """Classify free-text queries into one or more domains with a causal LM.

    The model is prompted with the list of allowed domains and a JSON output
    schema; its greedy completion is parsed into a normalized result dict
    (see ``_parse_response``).
    """

    # Matches a `"primary_domain": "<name>"` pair in (possibly broken) JSON text.
    _PRIMARY_RE = re.compile(r'"primary_domain"\s*:\s*"([^"]+)"')

    def __init__(self, model, tokenizer, domains):
        """Store the model, tokenizer and allowed domain labels.

        Args:
            model: Causal LM exposing ``eval()``, ``generate()`` and ``device``.
            tokenizer: Matching tokenizer providing ``apply_chat_template``.
            domains: Domain labels listed in the prompt and used by the
                text-based fallback parser.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.domains = domains
        self.model.eval()

    def predict(self, text: str) -> dict:
        """Classify *text* by prompting the model and parsing its reply.

        Returns:
            The dict produced by ``_parse_response``.
        """
        prompt = self._create_prompt(text)

        # NOTE(review): right-side truncation would cut off the end of the chat
        # prompt (including the generation marker) for very long queries.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                # Greedy decoding; a `temperature` setting would be ignored here.
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
                # NOTE(review): KV cache disabled — presumably a workaround for the
                # remote Phi-3 code; it makes generation slower. Confirm before enabling.
                use_cache=False,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return self._parse_response(response)

    def _create_prompt(self, text: str) -> str:
        """Build the chat-formatted classification prompt for *text*."""
        # NOTE(review): the fine-tuned model expects the exact training-time system
        # prompt; confirm this text (including whitespace/blank lines) matches it.
        system_prompt = f"""You are a multi-domain classifier. Classify queries into domains and detect if they span multiple domains.
Available domains:
{', '.join(self.domains)}
Output format (JSON):
{{
"primary_domain": "domain_name",
"primary_confidence": 0.95,
"is_multi_domain": true/false,
"secondary_domains": [
{{"domain": "domain_name", "confidence": 0.85}}
]
}}
Rules:
- primary_domain: Main domain
- primary_confidence: Score (0.0-1.0)
- is_multi_domain: true if multiple domains, false otherwise
- secondary_domains: List (empty if single-domain)"""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Classify this query: {text}"},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _parse_response(self, response: str) -> dict:
        """Turn the model's raw text into a normalized classification dict.

        Strategies, in order:
        1. a JSON object (optionally inside a ``` fence or surrounded by prose);
        2. a ``"primary_domain": "<known domain>"`` pair anywhere in the text;
        3. the known domain name that appears earliest in the text;
        otherwise ``general_knowledge`` with confidence 0.3.

        Returns:
            dict with ``primary_domain`` (str), ``primary_confidence`` (float),
            ``is_multi_domain`` (bool) and ``secondary_domains``
            (list of ``{"domain": str, "confidence": float}``).
        """
        parsed = self._extract_json(response)
        if parsed is not None:
            return {
                "primary_domain": str(parsed.get("primary_domain") or "general_knowledge"),
                "primary_confidence": self._to_float(parsed.get("primary_confidence"), 0.5),
                "is_multi_domain": self._to_bool(parsed.get("is_multi_domain", False)),
                "secondary_domains": self._clean_secondary(parsed.get("secondary_domains")),
            }

        lowered = response.lower()
        match = self._PRIMARY_RE.search(lowered)
        if match and match.group(1).strip() in self.domains:
            domain = match.group(1).strip()
        else:
            # Prefer the domain the text mentions first, not the first in list order.
            hits = [(lowered.find(name), name) for name in self.domains if name in lowered]
            domain = min(hits)[1] if hits else None

        if domain is not None:
            return {
                "primary_domain": domain,
                "primary_confidence": 0.5,
                "is_multi_domain": False,
                "secondary_domains": [],
            }
        return {
            "primary_domain": "general_knowledge",
            "primary_confidence": 0.3,
            "is_multi_domain": False,
            "secondary_domains": [],
        }

    @staticmethod
    def _extract_json(response: str):
        """Return the first JSON object found in *response*, or ``None``."""
        text = response.strip()
        if "```" in text:
            # Keep the content of the first fenced block.
            text = text.split("```")[1].strip()
        if text.startswith("json"):  # fence language tag, e.g. ```json
            text = text[4:].strip()

        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            # Tolerate prose before/after the object.
            candidates.append(text[start:end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _to_float(value, default: float) -> float:
        """Convert *value* to float, returning *default* if that is impossible."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _to_bool(value) -> bool:
        """Convert a JSON-ish flag to bool; the string "false" is False."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    @classmethod
    def _clean_secondary(cls, items) -> list:
        """Keep only dict entries that name a domain; coerce confidences to float."""
        if not isinstance(items, list):
            return []
        return [
            {"domain": str(item["domain"]), "confidence": cls._to_float(item.get("confidence"), 0.0)}
            for item in items
            if isinstance(item, dict) and item.get("domain")
        ]
# Single shared classifier instance used by classify_query().
classifier = MultiDomainClassifier(
    model=model,
    tokenizer=tokenizer,
    domains=DOMAINS,
)
# ============================================================================
# UI FUNCTIONS
# ============================================================================
def create_confidence_chart(primary_domain, primary_conf, secondary_domains):
    """Build a horizontal bar chart of the primary and top secondary confidences.

    Args:
        primary_domain: Name of the primary domain (drawn in green, on top).
        primary_conf: Primary confidence; the x-axis is fixed to [0, 1].
        secondary_domains: List of ``{"domain": ..., "confidence": ...}`` dicts;
            at most the first three well-formed entries are drawn (in blue).
            ``None`` is treated as an empty list.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    domains = [primary_domain]
    confidences = [primary_conf]
    colors = ["#2ecc71"]  # primary: green

    # Skip malformed entries instead of raising KeyError on partial model output.
    valid_secondaries = [
        sec for sec in (secondary_domains or [])
        if isinstance(sec, dict) and "domain" in sec
    ]
    for sec in valid_secondaries[:3]:
        domains.append(sec["domain"])
        confidences.append(float(sec.get("confidence") or 0.0))
        colors.append("#3498db")  # secondary: blue

    fig = go.Figure(
        data=[
            go.Bar(
                y=domains,
                x=confidences,
                orientation="h",
                marker=dict(color=colors),
                text=[f"{c:.1%}" for c in confidences],
                textposition="outside",
                # Keep "outside" labels visible for bars that reach the axis end.
                cliponaxis=False,
            )
        ]
    )
    fig.update_layout(
        title="Confidence Scores",
        xaxis_title="Confidence",
        yaxis_title="Domain",
        xaxis=dict(range=[0, 1], tickformat=".0%"),
        # Plotly draws the first category at the bottom; reverse so the primary is on top.
        yaxis=dict(autorange="reversed"),
        height=max(250, len(domains) * 60),
        margin=dict(l=150, r=50, t=50, b=50),
    )
    return fig
def classify_query(query_text):
    """Classify a user query and format the result for the Gradio outputs.

    Args:
        query_text: Raw text from the query textbox (may be empty or ``None``).

    Returns:
        A 4-tuple ``(primary_md, secondary_md, json_md, figure)``. Empty input
        yields a warning plus blank outputs; any exception during
        classification is reported as an error message rather than raised.
    """
    if not query_text or not query_text.strip():
        return "⚠️ Please enter a query", "", "", None

    try:
        result = classifier.predict(query_text.strip())

        primary_domain = str(result.get("primary_domain") or "general_knowledge")
        primary_conf = float(result.get("primary_confidence") or 0.0)
        is_multi = bool(result.get("is_multi_domain"))
        # Keep only well-formed entries so formatting and the chart cannot hit
        # a KeyError on partially-specified model output.
        secondary_domains = [
            {"domain": str(sec["domain"]), "confidence": float(sec.get("confidence") or 0.0)}
            for sec in (result.get("secondary_domains") or [])
            if isinstance(sec, dict) and sec.get("domain")
        ]

        # Primary output. Fields are separated by blank lines because consecutive
        # lines would otherwise be merged into one Markdown paragraph.
        conf_emoji = "🟢" if primary_conf > 0.85 else "🟡" if primary_conf > 0.60 else "🔴"
        primary_output = "\n\n".join([
            "### 🎯 Primary Domain",
            f"**Domain:** `{primary_domain.upper()}`",
            f"**Confidence:** {conf_emoji} **{primary_conf:.1%}**",
            f"**Multi-Domain:** {'✅ Yes' if is_multi else '❌ No'}",
        ])

        # Secondary output (numbered Markdown list).
        if secondary_domains:
            items = []
            for rank, sec in enumerate(secondary_domains, 1):
                sec_conf = sec["confidence"]
                sec_emoji = "🟢" if sec_conf > 0.70 else "🟡" if sec_conf > 0.50 else "🔴"
                items.append(f"{rank}. **{sec['domain']}** {sec_emoji} {sec_conf:.1%}")
            secondary_output = "### 📌 Secondary Domains\n\n" + "\n".join(items) + "\n"
        else:
            secondary_output = "### 📌 Secondary Domains\n\n*None (single-domain query)*"

        json_output = f"```json\n{json.dumps(result, indent=2)}\n```"
        chart = create_confidence_chart(primary_domain, primary_conf, secondary_domains)
        return primary_output, secondary_output, json_output, chart
    except Exception as e:  # UI boundary: show the error instead of failing the request
        return f"❌ Error: {e}", "", "", None
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Markdown sections. Blank lines separate paragraphs and lists; without them,
# a line following a list is absorbed into the last list item.
_HEADER_MD = """
# 🎯 Multi-Domain Classifier

Fine-tuned **Phi-3** model for classifying queries into 15 domains with multi-domain detection.

**Model:** [ovinduG/multi-domain-classifier-phi3](https://huggingface.co/ovinduG/multi-domain-classifier-phi3)
"""

_ABOUT_MD = """
### 📊 About

**Domains (15):**

- 🖥️ Coding
- 🔌 API Generation
- 📐 Mathematics
- 📊 Data Analysis
- 🔬 Science
- 🏥 Medicine
- 💼 Business
- ⚖️ Law
- 💻 Technology
- 📚 Literature
- 🎨 Creative Content
- 🎓 Education
- 🌍 General Knowledge
- ❓ Ambiguous
- 🔒 Sensitive

**Features:**

- Primary domain detection
- Multi-domain flagging
- Secondary domain ranking
- Confidence scores
"""

_LINKS_MD = """
### 🔗 Links

- [Model Repository](https://huggingface.co/ovinduG/multi-domain-classifier-phi3)
- [Base Model: Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Feedback & Issues](https://huggingface.co/ovinduG/multi-domain-classifier-phi3/discussions)
"""

with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Domain Classifier") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="📝 Enter Your Query",
                placeholder="Type your question here...",
                lines=3,
            )
            with gr.Row():
                classify_btn = gr.Button("🚀 Classify", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", size="lg")
            gr.Markdown("### 💡 Example Queries")
            gr.Examples(
                examples=EXAMPLES,
                inputs=query_input,
                label="Click to try",
            )
        with gr.Column(scale=1):
            gr.Markdown(_ABOUT_MD)

    gr.Markdown("---")
    gr.Markdown("## 📈 Results")
    with gr.Row():
        primary_output = gr.Markdown(label="Primary Domain")
        secondary_output = gr.Markdown(label="Secondary Domains")
    # NOTE(review): placement of the plot (below vs. inside the row above) is a
    # guess — confirm the intended layout.
    confidence_plot = gr.Plot(label="Confidence Visualization")
    with gr.Accordion("🔍 Raw JSON Output", open=False):
        json_output = gr.Markdown()

    gr.Markdown("---")
    gr.Markdown(_LINKS_MD)

    # Events — the order must match the tuple returned by classify_query.
    result_outputs = [primary_output, secondary_output, json_output, confidence_plot]
    classify_btn.click(
        fn=classify_query,
        inputs=query_input,
        outputs=result_outputs,
    )
    query_input.submit(
        fn=classify_query,
        inputs=query_input,
        outputs=result_outputs,
    )
    # Clear resets the query textbox as well as every result component.
    clear_btn.click(
        fn=lambda: ("", "", "", "", None),
        outputs=[query_input, *result_outputs],
    )
# ============================================================================
# LAUNCH
# ============================================================================
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script;
    # importing the module just builds `demo` without serving it.
    demo.launch()