import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            logger.info("Starting model loading...")

            # Check if CUDA is available
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"

            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Get HF token from environment
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map="auto" if self.device == "cuda:0" else None,
                trust_remote_code=True,
                token=hf_token,
                attn_implementation="eager"  # Use eager attention (compatible)
            )

            # Set pad token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            logger.info("βœ… Model loaded successfully!")

        except Exception as e:
            logger.error(f"❌ Error loading model: {str(e)}")
            self.model_loaded = False


def generate_response(prompt, temperature=0.8):
    """ZERO TRUNCATION GENERATION - Never cut anything!"""
    global model_manager

    if not model_manager or not model_manager.model_loaded:
        return "Model not loaded"

    try:
        # Detect CoT requests
        is_cot = any(phrase in prompt.lower() for phrase in [
            "return exactly this json array",
            "chain of thinking",
            "verbatim"
        ])

        logger.info(f"🎯 Request type: {'CoT' if is_cot else 'Standard'}")

        # Simple system message
        if is_cot:
            system = "You are an expert at generating JSON training data exactly as requested."
        else:
            system = "You are a helpful AI assistant."
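        # NOTE: a minimal alternative sketch, assuming a transformers version with
        # chat-template support. The hand-written Llama 3 header tokens below could
        # instead be produced by the tokenizer's built-in chat template, which is
        # less error-prone if the template ever changes:
        #
        #   messages = [
        #       {"role": "system", "content": system},
        #       {"role": "user", "content": prompt},
        #   ]
        #   formatted = model_manager.tokenizer.apply_chat_template(
        #       messages, tokenize=False, add_generation_prompt=True
        #   )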
        # Format prompt (Llama 3 chat layout)
        formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

        # Optimized token limits for speed
        if is_cot:
            max_new = 1500  # Reduced for speed
            min_new = 400   # Reduced minimum
        else:
            max_new = 800  # Significantly reduced for speed
            min_new = 50   # Lower minimum

        max_input = 6000  # Safe input limit

        logger.info(f"πŸ”’ Token allocation: Input≀{max_input}, Output={min_new}-{max_new}")

        # Tokenize
        inputs = model_manager.tokenizer(
            formatted,
            return_tensors="pt",
            truncation=True,
            max_length=max_input
        )

        # Move to device
        if model_manager.device == "cuda:0":
            inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}

        logger.info("πŸš€ Starting generation...")

        # Generate with generous parameters
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=max_new,
                min_new_tokens=min_new,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                num_beams=1,  # Single beam; with do_sample=True this is sampling, not greedy/beam search
                pad_token_id=model_manager.tokenizer.eos_token_id,
                repetition_penalty=1.1,
                use_cache=True
            )

        # Decode the COMPLETE response. Keep special tokens here: the assistant
        # header marker is needed for the extraction below, and decoding with
        # skip_special_tokens=True would strip it so the marker search could
        # never match.
        full_response = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=False)

        logger.info(f"πŸ“ Full response length: {len(full_response)} chars")
        logger.info(f"πŸ“ Response preview: {full_response[:200]}...")

        # ZERO TRUNCATION EXTRACTION - Find content intelligently but never cut
        response = full_response

        # Look for the assistant response marker
        assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
        if assistant_marker in full_response:
            # Find the position after the marker
            marker_pos = full_response.find(assistant_marker)
            if marker_pos != -1:
                # Start after the marker + some whitespace
                start_pos = marker_pos + len(assistant_marker)
                # Skip any immediate whitespace/newlines
                while start_pos < len(full_response) and full_response[start_pos] in ' \n\r\t':
                    start_pos += 1
                if start_pos < len(full_response):
                    response = full_response[start_pos:]
                    logger.info(f"βœ‚οΈ Extracted after assistant marker: {len(response)} chars")
                else:
                    logger.info("πŸ”„ Marker found but no content after, using full response")
            else:
                logger.info("πŸ”„ Marker search failed, using full response")
        else:
            logger.info("πŸ”„ No assistant marker found, using full response")

        # Strip any special tokens still left in the extracted text
        for special_token in ("<|eot_id|>", "<|end_of_text|>", "<|begin_of_text|>"):
            response = response.replace(special_token, "")

        # For CoT, if we have a JSON array, extract it cleanly
        if is_cot and '[' in response and ']' in response:
            # Find the outermost JSON array
            first_bracket = response.find('[')
            last_bracket = response.rfind(']')
            if first_bracket != -1 and last_bracket != -1 and last_bracket > first_bracket:
                json_candidate = response[first_bracket:last_bracket + 1]
                # Validate it contains the expected structure
                if '"user"' in json_candidate and '"assistant"' in json_candidate:
                    # Count the objects to make sure we have multiple items
                    user_count = json_candidate.count('"user"')
                    if user_count >= 2:  # Should have at least 2 user/assistant pairs
                        response = json_candidate
                        logger.info(f"🎯 Extracted JSON array with {user_count} items: {len(response)} chars")
                    else:
                        logger.info(f"⚠️ JSON array has only {user_count} items, using full response")
                else:
                    logger.info("⚠️ JSON candidate failed validation, using full response")
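        # A stricter validation sketch (an assumption, not part of the original flow):
        # instead of counting '"user"' substrings, the candidate could be parsed with
        # the stdlib json module imported at the top of this file:
        #
        #   try:
        #       items = json.loads(json_candidate)
        #       if isinstance(items, list) and len(items) >= 2:
        #           response = json_candidate
        #   except json.JSONDecodeError:
        #       pass  # keep the heuristic extraction result above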
        # Final response
        response = response.strip()

        logger.info(f"βœ… FINAL response: {len(response)} chars")
        logger.info(f"🎬 Starts with: {response[:150]}...")
        logger.info(f"🎭 Ends with: ...{response[-150:]}")

        return response

    except Exception as e:
        logger.error(f"πŸ’₯ Generation error: {e}")
        return f"Error: {e}"


# Initialize model ONCE
model_manager = ModelManager()


def api_respond(message, history_str, temperature, json_mode, template):
    """ZERO TRUNCATION API - Pure content, no wrappers"""
    try:
        logger.info(f"πŸ“¨ API Request: {len(message)} chars, temp={temperature}")
        response = generate_response(message, temperature)
        logger.info(f"πŸ“€ API Response: {len(response)} chars")
        return response
    except Exception as e:
        logger.error(f"πŸ’₯ API Error: {e}")
        return f"Error: {e}"


# BULLETPROOF GRADIO INTERFACE
demo = gr.Interface(
    fn=api_respond,
    inputs=[
        gr.Textbox(label="Message", lines=8, placeholder="Enter your prompt here..."),
        gr.Textbox(label="History", value="[]", visible=False),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Textbox(label="JSON Mode", value="", visible=False),
        gr.Textbox(label="Template", value="", visible=False)
    ],
    outputs=gr.Textbox(label="Response", lines=20, max_lines=50),
    title="🎯 Question Generation API - ZERO TRUNCATION",
    description="Rebuilt from scratch with ZERO text cutting. Generates complete responses every time.",
    api_name="respond"
)

if __name__ == "__main__":
    # Enable queue with concurrency limit of 10
    demo.queue(
        default_concurrency_limit=10,  # Handle 10 concurrent requests
        max_size=100  # Allow up to 100 requests in queue
    ).launch(server_name="0.0.0.0", server_port=7860, share=False)
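# A minimal client-side sketch (assumptions: the app is reachable at
# localhost:7860 and the gradio_client package is installed; the prompt text
# is only an illustration):
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/")
#   result = client.predict(
#       "Generate three Q&A pairs about Python generators.",  # message
#       "[]",   # history (hidden placeholder input)
#       0.8,    # temperature
#       "",     # json_mode (hidden placeholder input)
#       "",     # template (hidden placeholder input)
#       api_name="/respond",
#   )
#   print(result)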