import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelManager:
    """Owns the Llama-3.1-8B-Instruct model and tokenizer.

    Loading happens eagerly in __init__; `model_loaded` records whether it
    succeeded so callers can degrade gracefully instead of crashing.
    """

    def __init__(self):
        self.model = None        # AutoModelForCausalLM once loaded
        self.tokenizer = None    # AutoTokenizer once loaded
        self.device = None       # "cuda:0" or "cpu"
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer, recording success in `model_loaded`."""
        try:
            logger.info("Starting model loading...")

            # Prefer the first CUDA device when available.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"

            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Gated model: an HF access token must be supplied via the environment.
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                # fp16 only makes sense on GPU; CPU inference stays in fp32.
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map="auto" if self.device == "cuda:0" else None,
                trust_remote_code=True,
                token=hf_token
            )

            # Llama ships no dedicated pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")

        except Exception as e:
            # Keep the app alive even if loading fails; generate_response
            # checks `model_loaded` and returns an error string instead.
            logger.error(f"❌ Error loading model: {str(e)}")
            self.model_loaded = False


def generate_response(prompt, temperature=0.8, model_manager=None):
    """Generate a completion for `prompt` with the managed Llama model.

    Args:
        prompt: User prompt text. Prompts that look like chain-of-thought /
            JSON-array requests get a stricter system message and a larger
            output-token budget.
        temperature: Sampling temperature (UI slider range is 0.1-1.0).
        model_manager: A loaded ModelManager; returns "Model not loaded"
            when absent or when loading failed.

    Returns:
        The generated text (for CoT requests, the best embedded JSON array
        when one is found), or an "Error: ..." string on failure.
    """
    if not model_manager or not model_manager.model_loaded:
        return "Model not loaded"

    try:
        # Detect request type
        is_cot_request = any(phrase in prompt.lower() for phrase in [
            "return exactly this json array",
            "chain of thinking",
            "verbatim",
            "json array (no other text)"
        ])

        # Get actual model context
        max_context = getattr(model_manager.model.config, "max_position_embeddings", 8192)
        logger.info(f"Model context: {max_context} tokens")

        # Simple, clear prompt formatting using the Llama-3 chat template tokens.
        if is_cot_request:
            system_msg = "You are an expert at generating JSON training data. Return only valid JSON arrays as requested, no additional text."
        else:
            system_msg = "You are a helpful AI assistant generating high-quality training data."

        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_msg}
<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

        # Token budget: CoT requests need substantial room to emit a
        # complete JSON array; min_new_tokens discourages early truncation.
        if is_cot_request:
            max_new_tokens = 3000
            min_new_tokens = 500
        else:
            max_new_tokens = 1500
            min_new_tokens = 50

        # Reserve headroom so prompt + generation fits the context window.
        max_input_tokens = max_context - max_new_tokens - 100
        logger.info(f"Token plan: Input≤{max_input_tokens}, Output={min_new_tokens}-{max_new_tokens}")

        # Tokenize (truncating over-long prompts to the reserved input budget).
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        )

        if model_manager.device == "cuda:0":
            # device_map="auto" may shard the model; follow the first parameter's device.
            inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}

        # Generation. NOTE: the previous `early_stopping=False` was dropped —
        # it is a beam-search-only flag and is ignored (with a warning) when
        # do_sample=True.
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # FIX: decode only the newly generated tokens. The old code decoded
        # the full sequence with skip_special_tokens=True and then split on
        # "<|start_header_id|>assistant<|end_header_id|>" — a special-token
        # sequence that skip_special_tokens had already removed, so the split
        # never matched; the fallback sliced `len(formatted_prompt)` *characters*
        # off a decoded string whose length no longer corresponds to the raw
        # prompt, corrupting the response. Slicing the token ids at the input
        # length is exact.
        input_len = inputs['input_ids'].shape[1]
        generated_len = outputs[0].shape[0] - input_len
        logger.info(f"Generated {generated_len} tokens (min was {min_new_tokens})")
        response = model_manager.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

        # For CoT, extract clean JSON if possible
        if is_cot_request and '[' in response and ']' in response:
            # Find the most complete JSON array (pattern tolerates one level
            # of nested brackets).
            json_pattern = r'\[(?:[^[\]]+|\[[^\]]*\])*\]'
            matches = re.findall(json_pattern, response, re.DOTALL)
            if matches:
                # Pick the longest match (most complete)
                best_match = max(matches, key=len)
                # Verify it has reasonable content before replacing the response.
                if '"user"' in best_match and '"assistant"' in best_match:
                    logger.info(f"Extracted JSON: {len(best_match)} chars")
                    response = best_match

        logger.info(f"Final response: {len(response)} chars")
        return response.strip()

    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Error: {e}"


# Initialize model (eager: the Space is useless without it, and failures are
# recorded rather than raised).
model_manager = ModelManager()


def respond(message, history, temperature):
    """Gradio chat-style callback: returns just the generated text.

    `history` is accepted for interface compatibility but unused.
    """
    try:
        response = generate_response(message, temperature, model_manager)
        return response
    except Exception as e:
        logger.error(f"Error in respond: {e}")
        return f"Error: {e}"


def api_respond(message, history=None, temperature=0.8, json_mode=None, template=None):
    """API endpoint matching the original client's expected payload shape.

    Returns [[user_msg_dict, assistant_msg_dict], ""] — errors are reported
    inside the assistant message rather than raised, so the client always
    receives a well-formed structure.
    """
    try:
        response = generate_response(message, temperature, model_manager)
        # Return in original format that client expects
        return [[
            {"role": "user", "metadata": None, "content": message, "options": None},
            {"role": "assistant", "metadata": None, "content": response, "options": None}
        ], ""]
    except Exception as e:
        logger.error(f"API Error: {e}")
        return [[
            {"role": "user", "metadata": None, "content": message, "options": None},
            {"role": "assistant", "metadata": None, "content": f"Error: {e}", "options": None}
        ], ""]
# Build the Gradio interface.
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Question Generation API - Elegant Architecture")

    # Manual UI: prompt + temperature on the left, generated text on the right.
    with gr.Row():
        with gr.Column():
            message_input = gr.Textbox(label="Message", placeholder="Enter your prompt...", lines=5)
            temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")
            submit_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            response_output = gr.Textbox(label="Response", lines=15, max_lines=30)

    def ui_respond(message, temperature):
        """Thin adapter that binds the shared model manager into the UI callback."""
        return generate_response(message, temperature, model_manager)

    submit_btn.click(ui_respond, inputs=[message_input, temperature_input], outputs=[response_output])

    # API tab inside the same Blocks: mirrors the external client's call
    # shape (message, history, temperature) and renders the raw JSON reply.
    with gr.Tab("API"):
        with gr.Row():
            api_message = gr.Textbox(label="Message", lines=3)
            api_temp = gr.Number(value=0.8, label="Temperature")
            api_submit = gr.Button("Call API")
        api_output = gr.JSON(label="API Response")
        api_submit.click(api_respond, inputs=[api_message, gr.State([]), api_temp], outputs=[api_output])

if __name__ == "__main__":
    # Bind to all interfaces so the container/Space is reachable externally.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)