import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            logger.info("Starting model loading...")

            # Check if CUDA is available
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"

            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Get HF token from environment
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map={"": 0} if self.device == "cuda:0" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                token=hf_token
            )

            # device_map={"": 0} already places the weights on the GPU, so no extra .to() call is needed
            self.model_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model_loaded = False

# Initialize model manager
model_manager = ModelManager()

def generate_response(prompt, temperature=0.8):
    """Simple function to generate a response from a prompt"""
    if not model_manager.model_loaded:
        return "Model not loaded yet. Please wait..."

    try:
        # Build the Llama-3.1 chat format
        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
        # Determine the context window and use as much of it as possible
        try:
            max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072)  # Llama 3.1 supports up to 131072 tokens
        except Exception:
            max_ctx = 131072  # Fall back to the maximum possible

        logger.info(f"Model max context: {max_ctx} tokens")

        # Detect if this is a chain-of-thinking request
        is_cot_request = ("chain-of-thinking" in prompt.lower() or
                          "chain of thinking" in prompt.lower() or
                          "Return exactly this JSON array" in prompt or
                          ("verbatim" in prompt.lower() and "json array" in prompt.lower()))

        # Maximize generation tokens - use most of the context for generation
        if is_cot_request:
            # For CoT, use the maximum possible generation budget
            gen_max_new_tokens = 16384  # Very high limit for complete responses
            min_tokens = 2000  # High minimum to force complete generation
            # Allow most of the remaining context for input
            allowed_input_tokens = max_ctx - gen_max_new_tokens - 100  # Small safety buffer
            logger.info(f"CoT REQUEST - MAXIMIZED: min_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}, input_limit={allowed_input_tokens}")
        else:
            # Standard requests
            gen_max_new_tokens = 8192
            min_tokens = 200
            allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
        # Tokenize the input with safe truncation
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=allowed_input_tokens
        )

        # Move inputs to the same device as the model
        if model_manager.device == "cuda:0":
            model_device = next(model_manager.model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

        # Generate the response with maximized settings
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=gen_max_new_tokens,
                min_new_tokens=min_tokens,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                num_beams=1,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                eos_token_id=model_manager.tokenizer.eos_token_id,
                early_stopping=False,  # Never stop early
                repetition_penalty=1.05,
                no_repeat_ngram_size=0,
                length_penalty=1.0,
                use_cache=True  # KV cache speeds up long generations
            )
        # Decode only the newly generated tokens. Decoding the full sequence would
        # re-include the prompt, and skip_special_tokens strips the chat-template
        # header markers that prompt-splitting would otherwise have to rely on.
        input_length = inputs['input_ids'].shape[1]
        output_length = outputs[0].shape[0]
        generated_length = output_length - input_length
        generated_text = model_manager.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

        # Log generation details for debugging
        logger.info(f"Generation stats - Input: {input_length} tokens, Generated: {generated_length} tokens, Min required: {min_tokens}")
        if generated_length < min_tokens:
            logger.warning(f"Generated {generated_length} tokens but minimum was {min_tokens} - response may be truncated")
        # Post-decode guard: if a top-level JSON array closes, trim to the first full array.
        # This helps prevent trailing prose like 'assistant' or 'Message'.
        try:
            # Track both bracket and brace depth to find the first complete JSON structure
            bracket_depth = 0  # [ ]
            brace_depth = 0    # { }
            in_string = False
            escape_next = False
            start_idx = None
            end_idx = None

            for i, ch in enumerate(generated_text):
                # Handle string escaping
                if escape_next:
                    escape_next = False
                    continue
                if ch == '\\':
                    escape_next = True
                    continue

                # Track whether we're inside a string
                if ch == '"':
                    in_string = not in_string
                    continue

                # Only count brackets/braces outside of strings
                if not in_string:
                    if ch == '[':
                        if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
                            start_idx = i
                        bracket_depth += 1
                    elif ch == ']':
                        bracket_depth = max(0, bracket_depth - 1)
                        if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
                            end_idx = i
                            break
                    elif ch == '{':
                        brace_depth += 1
                    elif ch == '}':
                        brace_depth = max(0, brace_depth - 1)

            if start_idx is not None and end_idx is not None and end_idx > start_idx:
                # Extract just the complete JSON array
                json_text = generated_text[start_idx:end_idx + 1]
                logger.info(f"Extracted complete JSON array of length {len(json_text)}")
                generated_text = json_text
            elif start_idx is not None:
                # Found a start but no end - the response was truncated
                logger.warning("JSON array started but never closed - response truncated")
                # Keep what we have and let the client handle it
                generated_text = generated_text[start_idx:]
        except Exception as e:
            logger.warning(f"Error in JSON extraction: {e}")
        # Extract the assistant's reply. Only the newly generated tokens were decoded,
        # so the prompt and chat-template headers are normally already excluded; the
        # split below is just a safety net, and any echoed role name is stripped.
        if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
            generated_text = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
        response = generated_text.strip()
        if response.lower().startswith("assistant"):
            response = response[len("assistant"):].lstrip(": \n")
| logger.info(f"Generated response length: {len(response)} characters") | |
| return response | |
| except Exception as e: | |
| logger.error(f"Error generating response: {str(e)}") | |
| return f"Error: {str(e)}" | |

def respond(message, history, temperature):
    """Gradio interface function for chat"""
    response = generate_response(message, temperature)

    # Update history
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return history, ""

# Create the Gradio interface
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Simple LLM API")
    gr.Markdown("Send a prompt and get a response. No templates, just direct model interaction.")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=400
            )
            msg = gr.Textbox(
                label="Message",
                placeholder="Enter your prompt here...",
                lines=3
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative"
            )
| gr.Markdown(""" | |
| ### API Usage | |
| This model accepts any prompt and returns a response. | |
| For JSON responses, include instructions in your prompt like: | |
| - "Return as a JSON array" | |
| - "Format as JSON" | |
| - "List as JSON" | |
| The model will follow your instructions. | |
| """) | |
    # Set up event handlers
    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
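
# --- Example client usage (illustrative sketch, not part of the app) ---
# Assumes the Space is running and that Gradio exposes the `respond` handler under
# the default api_name "/respond" (derived from the function name); the Space ID
# below is a placeholder, not a real endpoint.
#
# from gradio_client import Client
#
# client = Client("your-username/your-space-name")  # hypothetical Space ID
# history, _ = client.predict(
#     "List three primary colors as a JSON array",  # message
#     [],                                           # chat history
#     0.8,                                          # temperature
#     api_name="/respond",
# )
# print(history[-1]["content"])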