import os
import re
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Loads and holds the Llama-3.1-8B-Instruct model and tokenizer.

    Attributes:
        model: the causal LM, or None until/unless loading succeeds.
        tokenizer: the matching tokenizer, or None.
        device: "cuda:0" when a GPU is available, else "cpu".
        model_loaded: True only after a fully successful load.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer; set self.model_loaded on success.

        Failures are logged (with traceback) rather than raised, so the
        Gradio app still comes up and can report "model not loaded".
        """
        try:
            logger.info("Starting model loading...")

            # Prefer the first CUDA GPU when available.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"

            logger.info("Using device: %s", self.device)
            if self.device == "cuda:0":
                logger.info("GPU: %s", torch.cuda.get_device_name())
                logger.info(
                    "VRAM Available: %.2f GB",
                    torch.cuda.get_device_properties(0).total_memory / 1024**3,
                )

            # Gated repo: requires a Hugging Face access token in the env.
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token,
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                # fp16 on GPU; fp32 on CPU (fp16 CPU inference is poorly supported).
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map={"": 0} if self.device == "cuda:0" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                token=hf_token,
            )

            # device_map already placed the weights on GPU 0; this extra .to()
            # is effectively a no-op kept so the handle reports the device.
            if self.device == "cuda:0":
                self.model = self.model.to(self.device)

            self.model_loaded = True
            logger.info("Model loaded successfully!")
        except Exception:
            logger.exception("Error loading model")
            self.model_loaded = False


# Initialize model manager (loads the model at import time).
model_manager = ModelManager()


def _is_cot_request(prompt):
    """Heuristic: does the prompt ask for chain-of-thinking / verbatim JSON output?"""
    lowered = prompt.lower()
    return (
        "chain-of-thinking" in lowered
        or "chain of thinking" in lowered
        or "Return exactly this JSON array" in prompt
        or ("verbatim" in lowered and "json array" in lowered)
    )


def _extract_first_json_array(text):
    """Trim *text* down to its first complete top-level JSON array, if any.

    Tracks bracket/brace depth while skipping over the contents of JSON
    strings, so brackets inside string values do not affect depth.  This
    prevents trailing prose (e.g. a stray 'assistant' token) from being
    appended after a closed array.

    Returns:
        The first complete top-level array when one is found; everything
        from the array's opening '[' onward when the array never closes
        (truncated generation); otherwise *text* unchanged.
    """
    bracket_depth = 0   # depth of [ ]
    brace_depth = 0     # depth of { }
    in_string = False
    escape_next = False
    start_idx = None
    end_idx = None

    for i, ch in enumerate(text):
        if in_string:
            # Escape sequences are only meaningful inside a JSON string.
            # (The original also honored backslashes outside strings, which
            # could swallow a structural bracket — fixed here.)
            if escape_next:
                escape_next = False
            elif ch == "\\":
                escape_next = True
            elif ch == '"':
                in_string = False
            continue

        if ch == '"':
            in_string = True
        elif ch == "[":
            if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
                start_idx = i
            bracket_depth += 1
        elif ch == "]":
            bracket_depth = max(0, bracket_depth - 1)
            if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
                end_idx = i
                break
        elif ch == "{":
            brace_depth += 1
        elif ch == "}":
            brace_depth = max(0, brace_depth - 1)

    if start_idx is not None and end_idx is not None and end_idx > start_idx:
        json_text = text[start_idx:end_idx + 1]
        logger.info("Extracted complete JSON array of length %d", len(json_text))
        return json_text
    if start_idx is not None:
        # Found a start but no end — the generation was cut off; hand the
        # partial array to the client and let it decide what to do.
        logger.warning("JSON array started but never closed - response truncated")
        return text[start_idx:]
    return text


def _strip_prompt(generated_text, formatted_prompt):
    """Return only the assistant's portion of the decoded model output.

    Tries, in order: the explicit assistant header marker; the first JSON
    bracket appearing past the prompt region; known prompt-end patterns;
    and finally a conservative character-count skip.
    """
    marker = "<|start_header_id|>assistant<|end_header_id|>"
    if marker in generated_text:
        return generated_text.split(marker)[-1].strip()

    # Fallback: look for the start of actual content (JSON or text).
    json_match = re.search(r"(\[|\{)", generated_text)
    if json_match and json_match.start() > len(formatted_prompt) // 2:
        return generated_text[json_match.start():].strip()

    prompt_end_patterns = ["<|end_header_id|>", "<|eot_id|>", "assistant", "\n\n"]
    response = generated_text
    for pattern in prompt_end_patterns:
        if pattern in generated_text:
            # Take the last substantial part after the pattern.
            candidate = generated_text.split(pattern)[-1].strip()
            if len(candidate) > 20:  # ensure it's not trivially short
                response = candidate
                break

    if response == generated_text:
        # Ultimate fallback — skip roughly the prompt, but conservatively.
        skip_chars = min(len(formatted_prompt) // 2, len(generated_text) // 3)
        response = generated_text[skip_chars:].strip()
    return response


def generate_response(prompt, temperature=0.8):
    """Generate a completion for *prompt* with the loaded Llama model.

    Args:
        prompt: raw user prompt; wrapped in the Llama-3.1 chat format.
        temperature: sampling temperature (higher = more random).

    Returns:
        The assistant's response text, or a human-readable error string
        when the model is not loaded or generation fails.
    """
    if not model_manager.model_loaded:
        return "Model not loaded yet. Please wait..."

    try:
        # Llama-3.1 chat template, built by hand.
        # NOTE(review): the original template's exact newline placement was
        # lost in the source; this is the canonical Llama-3.1 layout.
        formatted_prompt = (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        # Determine context window; Llama 3.1 supports up to 131k tokens.
        try:
            max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072)
        except Exception:
            max_ctx = 131072
        logger.info("Model max context: %s tokens", max_ctx)

        is_cot = _is_cot_request(prompt)
        if is_cot:
            # Chain-of-thinking: very high generation budget and a high
            # minimum to force complete multi-part answers.
            gen_max_new_tokens = 16384
            min_tokens = 2000
        else:
            gen_max_new_tokens = 8192
            min_tokens = 200

        # Reserve generation room plus a small safety buffer; never let the
        # tokenizer limit go non-positive on small-context models.
        allowed_input_tokens = max(1, max_ctx - gen_max_new_tokens - 100)
        if is_cot:
            logger.info(
                "CoT REQUEST - MAXIMIZED: min_tokens=%d, max_new_tokens=%d, input_limit=%d",
                min_tokens, gen_max_new_tokens, allowed_input_tokens,
            )

        # Tokenize the input with safe truncation.
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=allowed_input_tokens,
        )

        # Move inputs onto the same device as the model parameters.
        if model_manager.device == "cuda:0":
            model_device = next(model_manager.model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

        # Sampling generation.  (early_stopping / length_penalty are
        # beam-search-only knobs and were no-ops here — removed.)
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=gen_max_new_tokens,
                min_new_tokens=min_tokens,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                num_beams=1,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                eos_token_id=model_manager.tokenizer.eos_token_id,
                repetition_penalty=1.05,
                no_repeat_ngram_size=0,
                use_cache=True,
            )

        generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Log generation details for debugging.
        input_length = inputs["input_ids"].shape[1]
        generated_length = outputs[0].shape[0] - input_length
        logger.info(
            "Generation stats - Input: %d tokens, Generated: %d tokens, Min required: %d",
            input_length, generated_length, min_tokens,
        )
        if generated_length < min_tokens:
            logger.warning(
                "Generated %d tokens but minimum was %d - response may be truncated",
                generated_length, min_tokens,
            )

        # Post-decode guard: trim to the first complete top-level JSON array
        # to drop trailing prose like 'assistant' or 'Message'.
        try:
            generated_text = _extract_first_json_array(generated_text)
        except Exception as e:
            logger.warning("Error in JSON extraction: %s", e)

        response = _strip_prompt(generated_text, formatted_prompt)
        logger.info("Generated response length: %d characters", len(response))
        return response

    except Exception as e:
        logger.exception("Error generating response")
        return f"Error: {str(e)}"


def respond(message, history, temperature):
    """Gradio callback: run the model and append both turns to *history*."""
    response = generate_response(message, temperature)
    # Update history in the gr.Chatbot "messages" format.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return history, ""


# Create the Gradio interface.
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Simple LLM API")
    gr.Markdown("Send a prompt and get a response. No templates, just direct model interaction.")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=400,
            )
            msg = gr.Textbox(
                label="Message",
                placeholder="Enter your prompt here...",
                lines=3,
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative",
            )
            gr.Markdown(""" ### API Usage This model accepts any prompt and returns a response. For JSON responses, include instructions in your prompt like: - "Return as a JSON array" - "Format as JSON" - "List as JSON" The model will follow your instructions. """)

    # Set up event handlers.
    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )