# HuggingFace Spaces page residue ("Spaces: Sleeping") removed; see app code below.
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel | |
| import os | |
| import time | |
| import logging | |
| from datetime import datetime | |
# --- Logging -----------------------------------------------------------------
# Configure the root logger once at import time: INFO level, timestamped lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt='%H:%M:%S')

# Module-level logger used throughout this app.
logger = logging.getLogger(__name__)
class Config:
    """Static configuration tuned for slow CPU-only inference of a 7B model."""

    # Model identifiers / locations.
    MODEL_PATH = "navidfalah/3ai"  # Your fine-tuned model
    BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
    ADAPTER_PATH = "./model"

    # Generation limits — kept tiny so each request finishes in minutes, not hours.
    MAX_NEW_TOKENS = 50    # Very short for CPU speed
    MAX_INPUT_LENGTH = 128  # Short input for speed

    # Sampling parameters.
    TEMPERATURE = 0.7
    TOP_P = 0.9

    # Attempt 8-bit quantization on CPU (falls back to float32 if unsupported).
    USE_8BIT = True
# Lazily-populated singletons shared across requests (set by
# load_model_cpu_optimized on first use, then reused).
model = None            # loaded model (PeftModel, or base model if no adapter)
tokenizer = None        # tokenizer paired with the model
model_load_time = None  # seconds spent in the most recent full load
def log_time(start_time, operation):
    """Log the wall-clock time elapsed since *start_time* and return it.

    Args:
        start_time: Reference timestamp obtained from ``time.time()``.
        operation: Human-readable label for the log line.

    Returns:
        Elapsed seconds as a float.
    """
    elapsed = time.time() - start_time
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("%s took %.2f seconds", operation, elapsed)
    return elapsed
def _load_base_model_fp32():
    """Load the base model in plain float32 on CPU (safe fallback path)."""
    return AutoModelForCausalLM.from_pretrained(
        Config.BASE_MODEL,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        device_map="cpu"
    )


def _load_base_model():
    """Load the base Mistral model, preferring 8-bit quantization when enabled.

    8-bit on CPU is experimental — bitsandbytes typically requires a GPU —
    so any failure falls back to a plain float32 load.
    """
    if not Config.USE_8BIT:
        return _load_base_model_fp32()
    logger.info("Using 8-bit quantization for CPU...")
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_compute_dtype=torch.float16,
            bnb_8bit_use_double_quant=False,
        )
        return AutoModelForCausalLM.from_pretrained(
            Config.BASE_MODEL,
            quantization_config=bnb_config,
            device_map={"": "cpu"},
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16
        )
    except Exception:
        # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit propagate.
        logger.warning("8-bit quantization failed, using float32...")
        return _load_base_model_fp32()


def load_model_cpu_optimized():
    """Load the fine-tuned model and tokenizer for CPU inference, caching globally.

    The first call is slow (minutes); subsequent calls return the cached pair.

    Returns:
        ``(model, tokenizer)`` on success, ``(None, None)`` on failure.
    """
    global model, tokenizer, model_load_time
    # Fast path: reuse the cached instances after the first (slow) load.
    if model is not None and tokenizer is not None:
        logger.info("Model already loaded, using cached version")
        return model, tokenizer
    total_start = time.time()
    try:
        logger.info("Starting to load fine-tuned Mistral model for CPU...")
        logger.warning("Note: 7B model on CPU will be slow. First load may take 2-5 minutes.")

        # --- Tokenizer -------------------------------------------------------
        start = time.time()
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Left padding is required for decoder-only generation.
        tokenizer.padding_side = "left"
        log_time(start, "Tokenizer loading")

        # --- Base model ------------------------------------------------------
        start = time.time()
        logger.info("Loading base Mistral model with CPU optimizations...")
        base_model = _load_base_model()
        log_time(start, "Base model loading")

        # --- Fine-tuned adapter ----------------------------------------------
        start = time.time()
        logger.info("Loading fine-tuned adapter...")
        try:
            # First try loading the adapter from the HuggingFace Hub.
            model = PeftModel.from_pretrained(
                base_model,
                Config.MODEL_PATH,
                is_trainable=False,
                torch_dtype=torch.float32
            )
            logger.info("✅ Loaded adapter from HuggingFace")
        except Exception as e:
            logger.warning(f"Could not load from HF: {e}")
            # Fall back to a local adapter directory if present.
            if os.path.exists(Config.ADAPTER_PATH):
                model = PeftModel.from_pretrained(
                    base_model,
                    Config.ADAPTER_PATH,
                    is_trainable=False,
                    torch_dtype=torch.float32
                )
                logger.info("✅ Loaded adapter from local path")
            else:
                # Last resort: serve the un-adapted base model.
                logger.error("No adapter found! Using base model only.")
                model = base_model
        log_time(start, "Adapter loading")

        # Switch to inference mode and use all CPU cores.
        model.eval()
        if hasattr(torch, 'set_num_threads'):
            torch.set_num_threads(os.cpu_count())
            logger.info(f"Set PyTorch threads to {os.cpu_count()}")

        model_load_time = log_time(total_start, "Total model loading")
        logger.info(f"✅ Model ready! Total parameters: ~{sum(p.numel() for p in model.parameters()) / 1e9:.1f}B")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        return None, None
def analyze_text(user_input, progress=gr.Progress()):
    """Analyze *user_input* with the fine-tuned model.

    Args:
        user_input: Free-text description of a life situation.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of ``(analysis_text, timing_report_markdown)``; on failure the
        first element is an error message.
    """
    start_time = time.time()
    if not user_input.strip():
        return "Please enter some text to analyze.", "No input provided"
    logger.info(f"Starting analysis for input: {user_input[:50]}...")

    progress(0.1, desc="Loading model (this may take 2-5 minutes on first run)...")

    # Load (or fetch cached) model, timing it for the report.
    model_start = time.time()
    model, tokenizer = load_model_cpu_optimized()
    model_time = time.time() - model_start
    if model is None or tokenizer is None:
        return "Error: Could not load model.", f"Model loading failed after {model_time:.2f}s"

    progress(0.3, desc="Model loaded, preparing input...")
    try:
        # Mistral instruction format: [INST] ... [/INST]
        prompt = f"[INST] Analyze this life situation and provide brief satisfaction analysis: {user_input} [/INST]"
        logger.info(f"Prompt length: {len(prompt)} characters")

        tokenize_start = time.time()
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=Config.MAX_INPUT_LENGTH,
            padding=True
        )
        tokenize_time = log_time(tokenize_start, "Tokenization")

        progress(0.5, desc="Generating response (this may take 1-3 minutes on CPU)...")
        input_ids = inputs['input_ids']
        logger.info(f"Input tokens: {input_ids.shape[1]}")
        logger.info(f"Generating up to {Config.MAX_NEW_TOKENS} new tokens...")

        # Generate with aggressive CPU optimizations.
        gen_start = time.time()
        with torch.no_grad():
            # CPU mixed-precision autocast. Note: torch.cpu.amp.autocast is
            # deprecated; torch.amp.autocast(device_type="cpu") is the
            # supported spelling.
            with torch.amp.autocast(device_type="cpu"):
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=Config.MAX_NEW_TOKENS,
                    temperature=Config.TEMPERATURE,
                    do_sample=True,
                    top_k=50,  # Limit sampling pool
                    top_p=Config.TOP_P,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    num_beams=1,  # No beam search (early_stopping removed: it only applies to beams)
                    use_cache=True,  # KV cache
                    repetition_penalty=1.1
                )
        gen_time = log_time(gen_start, "Generation")
        tokens_generated = outputs.shape[1] - input_ids.shape[1]
        tokens_per_second = tokens_generated / gen_time if gen_time > 0 else 0
        logger.info(f"Generated {tokens_generated} tokens at {tokens_per_second:.2f} tokens/second")

        progress(0.8, desc="Decoding response...")
        decode_start = time.time()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        decode_time = log_time(decode_start, "Decoding")

        # Keep only the model's answer, dropping the echoed prompt.
        if "[/INST]" in response:
            result = response.split("[/INST]")[-1].strip()
        else:
            result = response[len(prompt):].strip()
        if not result:
            result = "Analysis: Based on your input, I recommend focusing on balance across life domains."

        total_time = time.time() - start_time
        logger.info(f"✅ Total analysis time: {total_time:.2f}s")

        # Detailed timing report shown next to the analysis.
        timing_report = f"""### Performance Report
**Model Loading:**
- Time: {model_time:.2f}s {' (cached after first load)' if model_time < 1 else ''}
**Generation Details:**
- Tokenization: {tokenize_time:.2f}s
- Generation: {gen_time:.2f}s
- Decoding: {decode_time:.2f}s
- **Total: {total_time:.2f}s**
**Token Statistics:**
- Input tokens: {input_ids.shape[1]}
- Generated tokens: {tokens_generated}
- Speed: {tokens_per_second:.2f} tokens/second
**System Info:**
- Model: Fine-tuned Mistral-7B
- Device: CPU ({os.cpu_count()} cores)
- Quantization: {'8-bit' if Config.USE_8BIT else 'Float32'}
💡 **Tips for faster response:**
- Keep inputs under 50 words
- First run is slowest (model loading)
- Consider using GPU for 10-50x speedup
"""
        progress(1.0, desc="Complete!")
        return result, timing_report
    except Exception as e:
        error_msg = f"Error during analysis: {str(e)}"
        logger.error(error_msg)
        total_time = time.time() - start_time
        return error_msg, f"Failed after {total_time:.2f}s\nError: {str(e)}"
# Create optimized interface
# Gradio UI: left column = input + speed tips, right column = analysis output
# and a live performance report fed by analyze_text's second return value.
with gr.Blocks(title="Life Satisfaction Analysis", theme=gr.themes.Base()) as demo:
    gr.Markdown("""
    # Life Satisfaction Analysis (CPU Mode)
    Using fine-tuned Mistral-7B model. ⚠️ **CPU inference is slow** - expect 2-5 minutes per analysis.
    """)
    with gr.Row():
        with gr.Column():
            # Free-text input describing the user's situation.
            input_text = gr.Textbox(
                label="Describe your situation",
                placeholder="Example: I'm stressed at work (3/10) but happy with family (8/10)...",
                lines=3,
                max_lines=5
            )
            with gr.Row():
                submit_btn = gr.Button("🔍 Analyze", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Markdown("""
            **⚡ Speed Tips:**
            - Keep input brief (< 50 words)
            - First analysis loads model (2-5 min)
            - Next analyses are faster (~1-2 min)
            """)
        with gr.Column():
            # Model output (read-only) plus a markdown timing report.
            output_text = gr.Textbox(
                label="AI Analysis",
                lines=6,
                interactive=False
            )
            timing_info = gr.Markdown(
                value="*Performance metrics will appear here*"
            )
    # Quick examples
    gr.Examples(
        examples=[
            "Work is stressful, health okay, finances tight",
            "Happy job but no work-life balance",
            "Good health and relationships, career stagnant"
        ],
        inputs=input_text,
        label="Quick Examples"
    )
    # Event handlers
    # Analyze: one input, two outputs (analysis text, timing report).
    submit_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[output_text, timing_info]
    )
    # Clear resets all three components (input, output, timing report).
    clear_btn.click(
        fn=lambda: ("", "", "*Performance metrics will appear here*"),
        outputs=[input_text, output_text, timing_info]
    )
if __name__ == "__main__":
    # Startup banner with the key configuration values.
    logger.info("="*60)
    logger.info("Starting Life Satisfaction Analysis App")
    logger.info("="*60)
    logger.info(f"Model: {Config.MODEL_PATH}")
    logger.info(f"Base: {Config.BASE_MODEL}")
    logger.info(f"Device: CPU ({os.cpu_count()} cores)")
    logger.info(f"PyTorch: {torch.__version__}")
    logger.info(f"Max tokens: {Config.MAX_NEW_TOKENS}")
    logger.info("="*60)
    logger.info("⚠️ WARNING: 7B model on CPU is SLOW!")
    logger.info("First load: 2-5 minutes")
    logger.info("Per query: 1-3 minutes")
    logger.info("For faster inference, use GPU!")
    logger.info("="*60)

    # Optional: pre-load the model before serving so the first request is fast.
    # Off by default; enable with PRELOAD_MODEL=1 (was a hard-coded `if False:`).
    if os.environ.get("PRELOAD_MODEL", "").lower() in ("1", "true", "yes"):
        logger.info("Pre-loading model (this will take 2-5 minutes)...")
        pre_start = time.time()
        load_model_cpu_optimized()
        logger.info(f"Model pre-loaded in {time.time() - pre_start:.2f}s")

    # Queue serializes requests — essential on CPU where each query takes minutes.
    demo.queue()
    demo.launch()