Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| import json | |
| import os | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| from datetime import datetime | |
# ============================================================
# CONFIGURATION
# ============================================================

HF_MODEL_REPO = "ovinduG/phi3-domain-classifier-98.26"  # Your LoRA adapters
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # Base model + tokenizer

# Memory optimization: cap the CUDA allocator's split size to limit fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# The 16 supported domains, mapped to a human-readable description.
DOMAINS = {
    "coding": "Programming and software development",
    "api_generation": "API design and implementation",
    "mathematics": "Mathematical problems and concepts",
    "data_analysis": "Data science and analytics",
    "science": "Scientific queries and research",
    "medicine": "Medical and healthcare topics",
    "business": "Business and commerce",
    "law": "Legal matters and regulations",
    "technology": "Tech industry and products",
    "literature": "Books, writing, poetry",
    "creative_content": "Art, music, creative work",
    "education": "Learning and teaching",
    "general_knowledge": "General information",
    "ambiguous": "Unclear or multi-interpretation queries",
    "sensitive": "Sensitive topics requiring care",
    "multi_domain": "Cross-domain queries",
}

# Domains on which the classifier reportedly scored 100% (see UI legend).
PERFECT_DOMAINS = [
    "ambiguous", "creative_content", "education", "law",
    "literature", "medicine", "science", "sensitive", "technology",
]
# ============================================================
# MODEL LOADING WITH ALL FIXES
# ============================================================
print("=" * 80)
print("π Loading Phi-3 Domain Classifier...")
print("=" * 80)

try:
    # Load tokenizer from base model (fixes corruption issue)
    print("π₯ Loading tokenizer from Phi-3 base model...")
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
    )
    # Ensure a pad token exists: classification below tokenizes with
    # padding=True and passes tokenizer.pad_token_id to generate().
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("β Tokenizer loaded successfully")

    # Load base model with disk offloading for low memory.
    print("π₯ Loading Phi-3 base model (this may take a minute)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        offload_folder="offload",
        offload_state_dict=True,
    )
    print("β Base model loaded successfully")

    # Load LoRA adapters with offloading.
    # FIX: adapters must match the base model's dtype. The original mixed a
    # float32 base with float16 adapters, which is slow/unsupported on CPU
    # and can raise dtype-mismatch errors during the forward pass.
    print("π₯ Loading LoRA adapters (98.26% accuracy)...")
    model = PeftModel.from_pretrained(
        base_model,
        HF_MODEL_REPO,
        torch_dtype=torch.float32,
        offload_folder="offload",
    )
    model.config.use_cache = False  # generate() below overrides with use_cache=True
    model.eval()  # inference only — disables dropout etc.
    print("β LoRA adapters loaded successfully")
    print("=" * 80)
    print("π Model ready! 98.26% accuracy achieved!")
    print("=" * 80)
except Exception as e:
    print("=" * 80)
    print(f"β ERROR loading model: {e}")
    print("=" * 80)
    raise  # re-raise so the Space reports the failure instead of serving a broken UI
# ============================================================
# CLASSIFICATION FUNCTION - FULLY FIXED
# ============================================================
def _build_result(domain, confidence, response):
    """Assemble the standard result dict for a successfully parsed *domain*."""
    return {
        "domain": domain,
        "confidence": confidence,
        "description": DOMAINS.get(domain, "Unknown domain"),
        "is_perfect_domain": domain in PERFECT_DOMAINS,
        "is_multi_domain": domain == "multi_domain",
        "is_ambiguous": domain == "ambiguous",
        "raw_response": response,
    }


def _parse_response(response):
    """Parse the model's raw text output into a result dict.

    Tries JSON first; falls back to scanning the text for a known domain
    name; finally returns an "unknown" result.
    """
    try:
        response_clean = response.strip()
        # Strip a markdown code fence (```json ... ```) if the model emitted one.
        if '```' in response_clean:
            response_clean = response_clean.split('```')[1]
            if response_clean.startswith('json'):
                response_clean = response_clean[4:]
        result = json.loads(response_clean.strip())
        return _build_result(
            result.get('primary_domain', 'unknown'),
            result.get('confidence', 'medium'),
            response,
        )
    except (ValueError, AttributeError):
        # ValueError covers json.JSONDecodeError (malformed JSON);
        # AttributeError covers JSON that parses to a non-dict (no .get).
        # The original bare `except:` silently hid real bugs as "unknown".
        lowered = response.lower()
        # Fallback: search for domain names in the response text.
        for domain in DOMAINS.keys():
            if domain in lowered:
                return _build_result(domain, "medium", response)
        return {
            "domain": "unknown",
            "confidence": "low",
            "description": "Failed to classify",
            "is_perfect_domain": False,
            "is_multi_domain": False,
            "is_ambiguous": False,
            "raw_response": response,
        }


def classify_domain(text, temperature=0.1):
    """Classify *text* into one of the 16 supported domains.

    Args:
        text: The query to classify.
        temperature: Sampling temperature (lower = more deterministic).

    Returns:
        dict with keys: domain, confidence, description, is_perfect_domain,
        is_multi_domain, is_ambiguous, raw_response.  On failure the domain
        is "unknown" (unparseable output) or "error" (exception raised).
    """
    try:
        messages = [
            {
                "role": "system",
                "content": "You are a domain classifier. Respond with JSON containing primary_domain and confidence."
            },
            {
                "role": "user",
                "content": f"Classify: {text}"
            }
        ]

        # Apply chat template to get text first (compatibility fix), then
        # tokenize separately without return_dict.
        chat_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = tokenizer(
            chat_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # Move tensors to the model's device (may be CPU or offloaded).
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=50,
                temperature=temperature,
                do_sample=True,
                # Fall back to EOS if the tokenizer defines no pad token.
                pad_token_id=(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                ),
                use_cache=True,
            )

        # Decode only the newly generated tokens (skip the prompt).
        response = tokenizer.decode(
            outputs[0][input_ids.shape[-1]:],
            skip_special_tokens=True,
        )
        return _parse_response(response)
    except Exception as e:
        # Top-level boundary: surface the failure as a structured result
        # so the UI renders an error instead of crashing.
        return {
            "domain": "error",
            "confidence": "low",
            "description": f"Classification error: {str(e)}",
            "is_perfect_domain": False,
            "is_multi_domain": False,
            "is_ambiguous": False,
            "raw_response": str(e),
        }
# ============================================================
# GRADIO INTERFACE FUNCTIONS
# ============================================================
def classify_single_query(query, temperature):
    """Classify one query and format the outcome for the Gradio UI.

    Args:
        query: The text to classify.
        temperature: Sampling temperature forwarded to classify_domain.

    Returns:
        4-tuple of (main markdown, domain-info markdown, details markdown,
        raw JSON string).  On empty input the last three are empty strings.
    """
    if not query.strip():
        return "β οΈ Please enter a query", "", "", ""

    result = classify_domain(query, temperature)

    domain = result['domain']
    confidence = result['confidence']
    description = result['description']

    # Format main result.
    main_output = "## π― Classification Result\n\n"
    main_output += f"**Domain:** `{domain}`\n\n"
    main_output += f"**Description:** {description}\n\n"
    # FIX: str() guard — the model's JSON may yield a numeric confidence
    # (e.g. 0.95), on which the original `confidence.upper()` raised
    # AttributeError.
    main_output += f"**Confidence:** {str(confidence).upper()}\n\n"

    # Add status badges.
    badges = []
    if result['is_perfect_domain']:
        badges.append("β **Perfect Domain** (100% accuracy)")
    if result['is_multi_domain']:
        badges.append("π **Multi-Domain Query**")
    if result['is_ambiguous']:
        badges.append("β **Ambiguous Query**")
    if badges:
        main_output += "### Status\n"
        for badge in badges:
            main_output += f"- {badge}\n"

    # Domain info panel.
    domain_info = "### π Domain Information\n\n"
    domain_info += f"**Category:** {domain}\n\n"
    domain_info += f"**Description:** {description}\n\n"
    if domain in PERFECT_DOMAINS:
        domain_info += "**Accuracy:** β 100% (Perfect)\n\n"
    elif domain == "multi_domain":
        domain_info += "**Accuracy:** 71% F1\n\n"
    else:
        domain_info += "**Accuracy:** 93-99% F1\n\n"

    # Classification details panel.
    details = "### π Classification Details\n\n"
    details += f"**Query:** {query}\n\n"
    details += f"**Predicted Domain:** {domain}\n\n"
    details += f"**Confidence Level:** {confidence}\n\n"

    json_output = json.dumps(result, indent=2)
    return main_output, domain_info, details, json_output
def classify_batch_queries(queries_text, temperature):
    """Classify one query per line and build a markdown report.

    Returns a (markdown summary + table, JSON string) pair; on empty input
    the JSON string is "".
    """
    if not queries_text.strip():
        return "β οΈ Please enter queries (one per line)", ""

    queries = [line.strip() for line in queries_text.split('\n') if line.strip()]
    if not queries:  # defensive; the strip() check above already guarantees >= 1 line
        return "β οΈ No valid queries found", ""

    rows = []
    for query in queries:
        outcome = classify_domain(query, temperature)
        rows.append({
            "query": query,
            "domain": outcome['domain'],
            "confidence": outcome['confidence'],
            "is_multi_domain": outcome['is_multi_domain'],
            "is_ambiguous": outcome['is_ambiguous'],
        })

    # Tally how many queries landed in each domain.
    counts = {}
    for row in rows:
        counts[row['domain']] = counts.get(row['domain'], 0) + 1

    # Summary header + domain distribution.
    summary = f"## π Batch Classification Results\n\n"
    summary += f"**Total Queries:** {len(queries)}\n\n"
    summary += "### Domain Distribution\n\n"
    for domain, count in sorted(counts.items(), key=lambda item: item[1], reverse=True):
        share = (count / len(queries)) * 100
        summary += f"- **{domain}**: {count} ({share:.1f}%)\n"

    # Detailed per-query results as a markdown table.
    table_parts = [
        "\n\n## π Detailed Results\n\n",
        "| # | Query | Domain | Confidence |\n",
        "|---|-------|--------|------------|\n",
    ]
    for index, row in enumerate(rows, 1):
        text = row['query']
        query_short = text[:50] + "..." if len(text) > 50 else text
        flags = ""
        if row['domain'] in PERFECT_DOMAINS:
            flags += " β"
        if row['is_multi_domain']:
            flags += " π"
        if row['is_ambiguous']:
            flags += " β"
        table_parts.append(
            f"| {index} | {query_short} | {row['domain']}{flags} | {row['confidence']} |\n"
        )

    return summary + "".join(table_parts), json.dumps(rows, indent=2)
def get_domain_info():
    """Render a static markdown reference page covering all 16 domains."""
    parts = [
        "# π Domain Information\n\n",
        "## All 16 Supported Domains\n\n",
    ]
    for domain, description in DOMAINS.items():
        parts.append(f"### {domain}\n")
        parts.append(f"**Description:** {description}\n\n")
        # Accuracy line mirrors the per-domain figures shown elsewhere in the UI.
        if domain in PERFECT_DOMAINS:
            parts.append("**Accuracy:** β 100% (Perfect)\n\n")
        elif domain == "multi_domain":
            parts.append("**Accuracy:** 71% F1\n\n")
        else:
            parts.append("**Accuracy:** 93-99% F1\n\n")
    parts.append("\n## Legend\n\n")
    parts.append("- β Perfect Domain (100% precision & recall)\n")
    parts.append("- π Multi-Domain (queries spanning multiple areas)\n")
    parts.append("- β Ambiguous (unclear or multi-interpretation)\n")
    return "".join(parts)
# ============================================================
# GRADIO INTERFACE
# ============================================================
def create_interface():
    """Build and return the Gradio Blocks app (three tabs + footer)."""
    with gr.Blocks(
        title="Phi-3 Domain Classifier (98.26%)",
        theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
        # π Phi-3 Domain Classification Interface (98.26% Accuracy)
        Test your domain classifier with **98.26% accuracy**!
        This model classifies text into 16 domains with **9 perfect domains** (100% accuracy).
        β οΈ **Note:** Running on free tier - first query may take 30-60 seconds to load model into memory.
        """)
        with gr.Tabs():
            # TAB 1: SINGLE QUERY — classify one query with example prompts.
            with gr.Tab("π― Single Query"):
                gr.Markdown("### Classify a single query")
                with gr.Row():
                    with gr.Column(scale=2):
                        single_query = gr.Textbox(
                            label="Enter your query",
                            placeholder="e.g., Write a Python function to sort a list",
                            lines=3
                        )
                        single_temp = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=0.1,
                            step=0.01,
                            label="Temperature (lower = more deterministic)"
                        )
                        single_classify_btn = gr.Button("π Classify", variant="primary")
                    with gr.Column(scale=1):
                        gr.Markdown("### Quick Examples")
                        gr.Examples(
                            examples=[
                                ["Write a Python function to sort a list"],
                                ["What are the symptoms of diabetes?"],
                                ["Explain quantum entanglement"],
                                ["How to implement OAuth2 authentication?"],
                                ["What is the Pythagorean theorem?"],
                                ["Best practices for React development"],
                                ["Legal requirements for starting a business"],
                                ["Write a poem about nature"],
                            ],
                            inputs=single_query
                        )
                gr.Markdown("---")
                single_result = gr.Markdown(label="Result")
                with gr.Row():
                    with gr.Column():
                        single_domain_info = gr.Markdown(label="Domain Info")
                    with gr.Column():
                        single_details = gr.Markdown(label="Details")
                with gr.Accordion("π Raw JSON Output", open=False):
                    single_json = gr.Code(language="json", label="JSON")
                # Wire the button to the single-query handler.
                single_classify_btn.click(
                    fn=classify_single_query,
                    inputs=[single_query, single_temp],
                    outputs=[single_result, single_domain_info, single_details, single_json]
                )
            # TAB 2: BATCH QUERIES — classify many queries, one per line.
            with gr.Tab("π Batch Classification"):
                gr.Markdown("### Classify multiple queries at once (one per line)")
                batch_queries = gr.Textbox(
                    label="Enter queries (one per line)",
                    placeholder="Write a Python function\nWhat is diabetes?\nExplain quantum mechanics",
                    lines=10
                )
                batch_temp = gr.Slider(
                    minimum=0.01,
                    maximum=1.0,
                    value=0.1,
                    step=0.01,
                    label="Temperature"
                )
                with gr.Row():
                    batch_classify_btn = gr.Button("π Classify All", variant="primary")
                    batch_clear_btn = gr.ClearButton([batch_queries])
                gr.Markdown("---")
                batch_result = gr.Markdown(label="Summary")
                with gr.Accordion("π JSON Output", open=False):
                    batch_json = gr.Code(language="json", label="JSON")
                # Wire the button to the batch handler.
                batch_classify_btn.click(
                    fn=classify_batch_queries,
                    inputs=[batch_queries, batch_temp],
                    outputs=[batch_result, batch_json]
                )
            # TAB 3: DOMAIN INFO — static reference page.
            with gr.Tab("π Domain Information"):
                domain_info_display = gr.Markdown(value=get_domain_info())
        # Footer with model details and legend.
        gr.Markdown("""
        ---
        ### π Model Information
        - **Model:** Phi-3-mini-4k-instruct + LoRA
        - **Repository:** ovinduG/phi3-domain-classifier-98.26
        - **Accuracy:** 98.26%
        - **Perfect Domains:** 9/16 (100% accuracy)
        - **Training:** 25 epochs, LoRA rank 32
        **Legend:**
        - β Perfect Domain (100% accuracy)
        - π Multi-Domain Query
        - β Ambiguous Query
        ---
        β οΈ **Performance Note:** Running on free tier CPU with disk offloading.
        First query takes 30-60s. Upgrade to GPU for 2-3s response time.
        """)
    return demo
# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
    print("\nπ Creating Gradio interface...")
    demo = create_interface()
    print("β Interface created!")
    print("\nπ Launching...")
    demo.queue()   # queue requests so slow CPU inference doesn't drop clients
    demo.launch()  # plain launch — Hugging Face Spaces supplies host/port