| |
| """ |
Inference example for the Polyjuice MBTI model from Hugging Face.

Because this is a custom Rust/PyTorch model, it cannot be served through the
Hugging Face Inference API. Users need to download the model files and run
inference with the Rust binary.

This script shows how to download the files and prepare them for inference.
| """ |
|
|
| from huggingface_hub import hf_hub_download |
| import os |
| import subprocess |
| import json |
|
|
def download_model(repo_id="ElderRyan/polyjuice", cache_dir="./model_cache", *, filenames=None):
    """Download the Polyjuice model files from the Hugging Face Hub.

    Each file is fetched with ``hf_hub_download`` into ``cache_dir``, and
    progress is printed to stdout.

    Args:
        repo_id: Hub repository to download from.
        cache_dir: Local directory passed to ``hf_hub_download`` as its cache.
        filenames: Iterable of repository file names to fetch. Defaults to
            the MLP weights, the TF-IDF vectorizer and ``config.json``.

    Returns:
        A dict mapping each requested filename to the local path returned by
        ``hf_hub_download``, in download order.
    """
    if filenames is None:
        filenames = (
            "mlp_weights_multitask.pt",
            "tfidf_vectorizer_multitask.json",
            "config.json",
        )

    print(f"Downloading model from {repo_id}...")

    downloaded_paths = {}
    for filename in filenames:
        print(f" Downloading {filename}...")
        path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
        )
        downloaded_paths[filename] = path
        # The check mark was previously mis-encoded as "β" (lead byte 0xE2 only).
        print(f" \u2713 Saved to: {path}")

    return downloaded_paths
|
|
def show_usage():
    """Print instructions for running the downloaded model locally.

    The model is run by a Rust binary, so this only prints the shell steps:
    clone the project, copy the model files into it, then build and run it.
    It finishes with a pointer to the hosted web interface. Nothing is executed.
    """
    rule = "=" * 60
    steps = [
        ("1. Clone the Rust project:", [
            "git clone https://github.com/RyanKung/polyjuice",
            "cd polyjuice",
        ]),
        ("2. Copy downloaded model files:", [
            "mkdir -p models",
            "cp <downloaded_path>/mlp_weights_multitask.pt models/",
            "cp <downloaded_path>/tfidf_vectorizer_multitask.json models/",
        ]),
        ("3. Build and run:", [
            "cargo build --release",
            './target/release/psycial hybrid predict "Your text here"',
        ]),
    ]

    lines = [
        "\n" + rule,
        "MODEL DOWNLOADED SUCCESSFULLY",
        rule,
        "\nThis is a Rust-based model. To use it:",
    ]
    for heading, commands in steps:
        lines.append("\n" + heading)
        lines.extend(" " + command for command in commands)
    lines.extend([
        "\n" + rule,
        "\nAlternatively, use the web interface at:",
        "https://polyjuice.0xbase.ai",
        rule + "\n",
    ])

    # One print of the newline-joined lines emits exactly the same text as
    # printing each line separately.
    print("\n".join(lines))
|
|
def main():
    """Download the model, then print its config summary and usage steps.

    Calls ``download_model()`` with its defaults, reads the downloaded
    ``config.json`` and prints the ``model_type``, ``input_features``,
    ``architecture`` and ``accuracy`` entries. Any missing entry is shown as
    ``N/A``. Finally calls ``show_usage()``.
    """
    # The banner glyphs were mis-encoded ("βββ…"), and the padding no longer
    # aligned. Rebuild the box with a computed, centred title.
    title = "Polyjuice MBTI Classifier - Model Download"
    width = 59  # inner width of the banner box
    print("\n\u2554" + "\u2550" * width + "\u2557")
    print("\u2551" + title.center(width) + "\u2551")
    print("\u255a" + "\u2550" * width + "\u255d\n")

    paths = download_model()

    # JSON is UTF-8 by specification; don't rely on the platform's default encoding.
    with open(paths["config.json"], "r", encoding="utf-8") as f:
        config = json.load(f)

    print("\n" + "=" * 60)
    print("MODEL INFORMATION")
    print("=" * 60)
    print(f"Model Type: {config.get('model_type', 'N/A')}")
    print(f"Input Features: {config.get('input_features', 'N/A')}")
    print(f"Architecture: {config.get('architecture', 'N/A')}")
    print("\nAccuracy:")

    # `or {}` also tolerates an explicit `"accuracy": null` in the config.
    acc = config.get("accuracy") or {}
    metrics = (
        ("Overall", "overall"),
        ("E/I", "e_i"),
        ("S/N", "s_n"),
        ("T/F", "t_f"),
        ("J/P", "j_p"),
    )
    for label, key in metrics:
        value = acc.get(key)
        # Append "%" only to real values, so a missing metric shows "N/A", not "N/A%".
        shown = "N/A" if value is None else f"{value}%"
        print(f" {label}: {shown}")

    show_usage()


if __name__ == "__main__":
    main()
|
|
|
|