# ovinduG's Hugging Face Space — app.py (commit 77fcfa5, verified)
import json
import os
from collections import Counter
from datetime import datetime

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# ============================================================
# CONFIGURATION
# ============================================================
# Hugging Face repo holding the fine-tuned LoRA adapter weights.
HF_MODEL_REPO = "ovinduG/phi3-domain-classifier-98.26" # Your LoRA adapters
# Base checkpoint the adapters were trained on; also supplies the tokenizer.
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct" # Base model + tokenizer
# Memory optimization: cap the CUDA allocator's split size to limit
# fragmentation. NOTE(review): only takes effect when a CUDA device is
# present; harmless on the CPU-only free tier.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Domain information: maps each of the 16 classifier labels to the
# human-readable description shown in the UI.
DOMAINS = {
    "coding": "Programming and software development",
    "api_generation": "API design and implementation",
    "mathematics": "Mathematical problems and concepts",
    "data_analysis": "Data science and analytics",
    "science": "Scientific queries and research",
    "medicine": "Medical and healthcare topics",
    "business": "Business and commerce",
    "law": "Legal matters and regulations",
    "technology": "Tech industry and products",
    "literature": "Books, writing, poetry",
    "creative_content": "Art, music, creative work",
    "education": "Learning and teaching",
    "general_knowledge": "General information",
    "ambiguous": "Unclear or multi-interpretation queries",
    "sensitive": "Sensitive topics requiring care",
    "multi_domain": "Cross-domain queries"
}
# Labels reported to reach 100% precision/recall on the evaluation set
# (presumably from the model card's eval results — TODO confirm).
PERFECT_DOMAINS = [
    "ambiguous", "creative_content", "education", "law",
    "literature", "medicine", "science", "sensitive", "technology"
]
# ============================================================
# MODEL LOADING WITH ALL FIXES
# ============================================================
print("=" * 80)
print("πŸ”„ Loading Phi-3 Domain Classifier...")
print("=" * 80)
try:
    # Tokenizer is taken from the base checkpoint rather than the adapter
    # repo, sidestepping a corrupted/partial tokenizer in the fine-tuned repo.
    print("πŸ“₯ Loading tokenizer from Phi-3 base model...")
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True
    )
    print("βœ… Tokenizer loaded successfully")

    # float32 weights (fp16 kernels are poorly supported on CPU) plus disk
    # offload so the free-tier instance stays inside its RAM budget.
    print("πŸ“₯ Loading Phi-3 base model (this may take a minute)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        offload_folder="offload",
        offload_state_dict=True
    )
    print("βœ… Base model loaded successfully")

    # FIX: load the LoRA adapters in the SAME dtype as the base model.
    # The original passed torch_dtype=torch.float16 here, mixing fp16
    # adapter weights with an fp32 base — on CPU that can fail outright or
    # silently upcast inconsistently. Keep everything float32.
    print("πŸ“₯ Loading LoRA adapters (98.26% accuracy)...")
    model = PeftModel.from_pretrained(
        base_model,
        HF_MODEL_REPO,
        torch_dtype=torch.float32,
        offload_folder="offload"
    )
    # use_cache=False on the config is a training-time setting; generate()
    # below passes use_cache=True explicitly, which overrides this per call.
    model.config.use_cache = False
    print("βœ… LoRA adapters loaded successfully")
    print("=" * 80)
    print("πŸŽ‰ Model ready! 98.26% accuracy achieved!")
    print("=" * 80)
except Exception as e:
    # Surface the failure prominently in the Space logs, then re-raise so
    # the container fails fast instead of serving a broken app.
    print("=" * 80)
    print(f"❌ ERROR loading model: {e}")
    print("=" * 80)
    raise
# ============================================================
# CLASSIFICATION FUNCTION - FULLY FIXED
# ============================================================
def _build_result(domain, confidence, description, raw_response):
    """Assemble the standard result dict returned by classify_domain."""
    return {
        "domain": domain,
        "confidence": confidence,
        "description": description,
        "is_perfect_domain": domain in PERFECT_DOMAINS,
        "is_multi_domain": domain == "multi_domain",
        "is_ambiguous": domain == "ambiguous",
        "raw_response": raw_response
    }


def classify_domain(text, temperature=0.1):
    """Classify *text* into one of the 16 supported domains.

    Prompts the fine-tuned Phi-3 model for a JSON verdict, parses it, and
    falls back to a substring scan of the raw response when the JSON is
    malformed. Never raises: every failure path returns a result dict with
    domain "unknown" or "error".

    Args:
        text: The query to classify.
        temperature: Sampling temperature for generation (default 0.1).

    Returns:
        dict with keys: domain, confidence, description, is_perfect_domain,
        is_multi_domain, is_ambiguous, raw_response.
    """
    try:
        messages = [
            {
                "role": "system",
                "content": "You are a domain classifier. Respond with JSON containing primary_domain and confidence."
            },
            {
                "role": "user",
                "content": f"Classify: {text}"
            }
        ]
        # Render the chat template to text first, then tokenize separately
        # (some tokenizer versions mishandle the tokenize=True path).
        chat_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        )
        inputs = tokenizer(
            chat_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        # Move tensors to wherever accelerate placed the model.
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(model.device)
        # FIX: some tokenizer revisions ship without a pad token; fall back
        # to EOS so generate() does not warn or fail on pad_token_id=None.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=50,
                temperature=temperature,
                do_sample=True,
                pad_token_id=pad_id,
                use_cache=True
            )
        # Decode only the newly generated tokens (skip the echoed prompt).
        response = tokenizer.decode(
            outputs[0][input_ids.shape[-1]:],
            skip_special_tokens=True
        )
        try:
            # Strip an optional markdown code fence (``` or ```json).
            response_clean = response.strip()
            if '```' in response_clean:
                response_clean = response_clean.split('```')[1]
                if response_clean.startswith('json'):
                    response_clean = response_clean[4:]
            result = json.loads(response_clean.strip())
            domain = result.get('primary_domain', 'unknown')
            confidence = result.get('confidence', 'medium')
            return _build_result(
                domain, confidence, DOMAINS.get(domain, "Unknown domain"), response
            )
        # FIX: the original bare `except:` swallowed everything up to and
        # including KeyboardInterrupt; catch only plausible parse failures
        # (json.JSONDecodeError is a ValueError; .get on a non-dict raises
        # AttributeError/TypeError).
        except (ValueError, TypeError, AttributeError):
            # Fallback: scan the raw response for any known domain name.
            for domain in DOMAINS.keys():
                if domain in response.lower():
                    return _build_result(
                        domain, "medium", DOMAINS.get(domain, "Unknown domain"), response
                    )
            return _build_result("unknown", "low", "Failed to classify", response)
    except Exception as e:
        # Last-resort guard so the Gradio handlers always receive a dict.
        return _build_result("error", "low", f"Classification error: {str(e)}", str(e))
# ============================================================
# GRADIO INTERFACE FUNCTIONS
# ============================================================
def classify_single_query(query, temperature):
    """Gradio handler for the single-query tab.

    Classifies *query* via classify_domain and renders four output panes.

    Args:
        query: User-entered query text.
        temperature: Sampling temperature forwarded to classify_domain.

    Returns:
        (main markdown, domain-info markdown, details markdown, JSON string).
    """
    if not query.strip():
        # Keep the 4-tuple arity even on the validation failure path.
        return "⚠️ Please enter a query", "", "", ""
    result = classify_domain(query, temperature)
    domain = result['domain']
    # FIX: the model's JSON sometimes reports confidence as a number
    # (e.g. 0.95); coerce to str so .upper() below cannot raise
    # AttributeError on a float.
    confidence = str(result['confidence'])
    description = result['description']
    # Main result pane.
    main_output = f"## 🎯 Classification Result\n\n"
    main_output += f"**Domain:** `{domain}`\n\n"
    main_output += f"**Description:** {description}\n\n"
    main_output += f"**Confidence:** {confidence.upper()}\n\n"
    # Status badges for the special domain categories.
    badges = []
    if result['is_perfect_domain']:
        badges.append("⭐ **Perfect Domain** (100% accuracy)")
    if result['is_multi_domain']:
        badges.append("πŸ”€ **Multi-Domain Query**")
    if result['is_ambiguous']:
        badges.append("❓ **Ambiguous Query**")
    if badges:
        main_output += "### Status\n"
        for badge in badges:
            main_output += f"- {badge}\n"
    # Domain info pane: description plus reported eval accuracy.
    domain_info = f"### πŸ“š Domain Information\n\n"
    domain_info += f"**Category:** {domain}\n\n"
    domain_info += f"**Description:** {description}\n\n"
    if domain in PERFECT_DOMAINS:
        domain_info += "**Accuracy:** ⭐ 100% (Perfect)\n\n"
    elif domain == "multi_domain":
        domain_info += "**Accuracy:** 71% F1\n\n"
    else:
        domain_info += "**Accuracy:** 93-99% F1\n\n"
    # Details pane: echo the query and the prediction.
    details = f"### πŸ” Classification Details\n\n"
    details += f"**Query:** {query}\n\n"
    details += f"**Predicted Domain:** {domain}\n\n"
    details += f"**Confidence Level:** {confidence}\n\n"
    # Raw JSON for the collapsible accordion.
    json_output = json.dumps(result, indent=2)
    return main_output, domain_info, details, json_output
def classify_batch_queries(queries_text, temperature):
"""Classify multiple queries"""
if not queries_text.strip():
return "⚠️ Please enter queries (one per line)", ""
queries = [q.strip() for q in queries_text.split('\n') if q.strip()]
if not queries:
return "⚠️ No valid queries found", ""
results = []
for query in queries:
result = classify_domain(query, temperature)
results.append({
"query": query,
"domain": result['domain'],
"confidence": result['confidence'],
"is_multi_domain": result['is_multi_domain'],
"is_ambiguous": result['is_ambiguous']
})
# Create summary
summary = f"## πŸ“Š Batch Classification Results\n\n"
summary += f"**Total Queries:** {len(queries)}\n\n"
# Domain distribution
domain_counts = {}
for r in results:
domain = r['domain']
domain_counts[domain] = domain_counts.get(domain, 0) + 1
summary += "### Domain Distribution\n\n"
for domain, count in sorted(domain_counts.items(), key=lambda x: x[1], reverse=True):
percentage = (count / len(queries)) * 100
summary += f"- **{domain}**: {count} ({percentage:.1f}%)\n"
# Detailed results table
table = "\n\n## πŸ“‹ Detailed Results\n\n"
table += "| # | Query | Domain | Confidence |\n"
table += "|---|-------|--------|------------|\n"
for i, r in enumerate(results, 1):
query_short = r['query'][:50] + "..." if len(r['query']) > 50 else r['query']
flags = ""
if r['domain'] in PERFECT_DOMAINS:
flags += " ⭐"
if r['is_multi_domain']:
flags += " πŸ”€"
if r['is_ambiguous']:
flags += " ❓"
table += f"| {i} | {query_short} | {r['domain']}{flags} | {r['confidence']} |\n"
json_output = json.dumps(results, indent=2)
return summary + table, json_output
def get_domain_info():
    """Build the markdown overview of every supported domain plus a legend."""
    def accuracy_line(name):
        # Per-domain accuracy string, mirroring the published eval numbers.
        if name in PERFECT_DOMAINS:
            return "**Accuracy:** ⭐ 100% (Perfect)\n\n"
        if name == "multi_domain":
            return "**Accuracy:** 71% F1\n\n"
        return "**Accuracy:** 93-99% F1\n\n"

    parts = ["# πŸ“š Domain Information\n\n", "## All 16 Supported Domains\n\n"]
    for name, blurb in DOMAINS.items():
        parts.append(f"### {name}\n")
        parts.append(f"**Description:** {blurb}\n\n")
        parts.append(accuracy_line(name))
    parts.append("\n## Legend\n\n")
    parts.append("- ⭐ Perfect Domain (100% precision & recall)\n")
    parts.append("- πŸ”€ Multi-Domain (queries spanning multiple areas)\n")
    parts.append("- ❓ Ambiguous (unclear or multi-interpretation)\n")
    return "".join(parts)
# ============================================================
# GRADIO INTERFACE
# ============================================================
def create_interface():
    """Build the Gradio Blocks UI: a single-query tab, a batch tab, and a
    static domain-information tab, framed by header/footer markdown.

    Returns the gr.Blocks app; the caller queues and launches it.
    """
    with gr.Blocks(
        title="Phi-3 Domain Classifier (98.26%)",
        theme=gr.themes.Soft()
    ) as demo:
        # Page header shown above the tabs.
        gr.Markdown("""
        # πŸ† Phi-3 Domain Classification Interface (98.26% Accuracy)
        Test your domain classifier with **98.26% accuracy**!
        This model classifies text into 16 domains with **9 perfect domains** (100% accuracy).
        ⚠️ **Note:** Running on free tier - first query may take 30-60 seconds to load model into memory.
        """)
        with gr.Tabs():
            # TAB 1: SINGLE QUERY
            with gr.Tab("🎯 Single Query"):
                gr.Markdown("### Classify a single query")
                with gr.Row():
                    with gr.Column(scale=2):
                        # Inputs: free-text query plus sampling temperature.
                        single_query = gr.Textbox(
                            label="Enter your query",
                            placeholder="e.g., Write a Python function to sort a list",
                            lines=3
                        )
                        single_temp = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=0.1,
                            step=0.01,
                            label="Temperature (lower = more deterministic)"
                        )
                        single_classify_btn = gr.Button("πŸš€ Classify", variant="primary")
                    with gr.Column(scale=1):
                        gr.Markdown("### Quick Examples")
                        # Clicking an example copies it into the query box.
                        gr.Examples(
                            examples=[
                                ["Write a Python function to sort a list"],
                                ["What are the symptoms of diabetes?"],
                                ["Explain quantum entanglement"],
                                ["How to implement OAuth2 authentication?"],
                                ["What is the Pythagorean theorem?"],
                                ["Best practices for React development"],
                                ["Legal requirements for starting a business"],
                                ["Write a poem about nature"],
                            ],
                            inputs=single_query
                        )
                gr.Markdown("---")
                # Output panes populated by classify_single_query.
                single_result = gr.Markdown(label="Result")
                with gr.Row():
                    with gr.Column():
                        single_domain_info = gr.Markdown(label="Domain Info")
                    with gr.Column():
                        single_details = gr.Markdown(label="Details")
                with gr.Accordion("πŸ” Raw JSON Output", open=False):
                    single_json = gr.Code(language="json", label="JSON")
                # Wire the button to the single-query handler (4 outputs).
                single_classify_btn.click(
                    fn=classify_single_query,
                    inputs=[single_query, single_temp],
                    outputs=[single_result, single_domain_info, single_details, single_json]
                )
            # TAB 2: BATCH QUERIES
            with gr.Tab("πŸ“Š Batch Classification"):
                gr.Markdown("### Classify multiple queries at once (one per line)")
                batch_queries = gr.Textbox(
                    label="Enter queries (one per line)",
                    placeholder="Write a Python function\nWhat is diabetes?\nExplain quantum mechanics",
                    lines=10
                )
                batch_temp = gr.Slider(
                    minimum=0.01,
                    maximum=1.0,
                    value=0.1,
                    step=0.01,
                    label="Temperature"
                )
                with gr.Row():
                    batch_classify_btn = gr.Button("πŸš€ Classify All", variant="primary")
                    # ClearButton resets the query textbox in one click.
                    batch_clear_btn = gr.ClearButton([batch_queries])
                gr.Markdown("---")
                batch_result = gr.Markdown(label="Summary")
                with gr.Accordion("πŸ” JSON Output", open=False):
                    batch_json = gr.Code(language="json", label="JSON")
                # Wire the button to the batch handler (2 outputs).
                batch_classify_btn.click(
                    fn=classify_batch_queries,
                    inputs=[batch_queries, batch_temp],
                    outputs=[batch_result, batch_json]
                )
            # TAB 3: DOMAIN INFO — static markdown, rendered once at build time.
            with gr.Tab("πŸ“š Domain Information"):
                domain_info_display = gr.Markdown(value=get_domain_info())
        # Footer: model card summary and legend.
        gr.Markdown("""
        ---
        ### πŸ† Model Information
        - **Model:** Phi-3-mini-4k-instruct + LoRA
        - **Repository:** ovinduG/phi3-domain-classifier-98.26
        - **Accuracy:** 98.26%
        - **Perfect Domains:** 9/16 (100% accuracy)
        - **Training:** 25 epochs, LoRA rank 32
        **Legend:**
        - ⭐ Perfect Domain (100% accuracy)
        - πŸ”€ Multi-Domain Query
        - ❓ Ambiguous Query
        ---
        ⚠️ **Performance Note:** Running on free tier CPU with disk offloading.
        First query takes 30-60s. Upgrade to GPU for 2-3s response time.
        """)
    return demo
# ============================================================
# MAIN
# ============================================================
# Entry point for direct runs (HF Spaces executes app.py as a script).
if __name__ == "__main__":
    print("\nπŸš€ Creating Gradio interface...")
    # `demo` is the conventional Spaces app-object name.
    demo = create_interface()
    print("βœ… Interface created!")
    print("\n🌐 Launching...")
    demo.queue()  # Enable queue for better performance
    demo.launch()  # Simple launch for Spaces