Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
| import json | |
| import re | |
| # Configure logging: set the root logger to INFO and create this module's logger, | |
| # which every function in this file uses for progress and error messages. | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class ModelManager:
    """Owns the tokenizer, the causal LM and the device they run on.

    Loading is attempted once, from the constructor. ``model_loaded`` records
    whether that attempt succeeded; callers check it before generating.
    """

    def __init__(self):
        # Everything starts unset; load_model() fills these in.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the Llama-3.1-8B-Instruct tokenizer and model.

        Picks ``cuda:0`` when CUDA is available and ``cpu`` otherwise, reads
        the Hugging Face token from the ``HF_TOKEN`` environment variable and
        sets ``model_loaded`` to True on success. Any exception is logged and
        leaves ``model_loaded`` False instead of propagating.
        """
        try:
            logger.info("Starting model loading...")

            # Device selection: first GPU if present, otherwise CPU.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"
            on_gpu = self.device == "cuda:0"

            logger.info(f"Using device: {self.device}")
            if on_gpu:
                total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {total_gib:.2f} GB")

            # Token for the Hub download (None when the variable is unset).
            auth_token = os.environ.get("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            model_id = "meta-llama/Llama-3.1-8B-Instruct"

            # Arguments shared by the tokenizer and model downloads.
            # NOTE(review): trust_remote_code=True allows code shipped in the
            # model repository to run locally - confirm it is actually needed.
            hub_kwargs = {"trust_remote_code": True, "token": auth_token}

            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                use_fast=True,
                **hub_kwargs,
            )

            # Half precision with automatic placement on GPU; full precision
            # and no device map on CPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                device_map="auto" if on_gpu else None,
                **hub_kwargs,
            )

            # Reuse EOS as the pad token when the tokenizer defines none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")
        except Exception as exc:
            logger.error(f"❌ Error loading model: {exc!s}")
            self.model_loaded = False
# Phrases that mark a prompt as a "chain of thinking" / strict-JSON request.
_COT_TRIGGER_PHRASES = (
    "return exactly this json array",
    "chain of thinking",
    "verbatim",
    "json array (no other text)",
)

# Matches a bracketed array with at most one level of nested brackets.
# Each alternative consumes either exactly one non-bracket character or one
# whole nested array, so an unterminated "[" (e.g. truncated JSON) fails fast
# instead of triggering exponential backtracking.
_JSON_ARRAY_PATTERN = re.compile(r'\[(?:[^\[\]]|\[[^\]]*\])*\]', re.DOTALL)


def _extract_json_array(text):
    """Return the longest bracketed array in *text* that contains both
    ``"user"`` and ``"assistant"``, or None when there is no such array."""
    candidates = _JSON_ARRAY_PATTERN.findall(text)
    if not candidates:
        return None
    # The longest match is taken to be the most complete array.
    longest = max(candidates, key=len)
    if '"user"' in longest and '"assistant"' in longest:
        return longest
    return None


def generate_response(prompt, temperature=0.8, model_manager=None):
    """Generate a completion for *prompt* with the loaded chat model.

    Args:
        prompt: User text. Prompts containing one of the trigger phrases in
            ``_COT_TRIGGER_PHRASES`` are treated as JSON ("CoT") requests and
            get a JSON-specific system message and a larger output budget.
        temperature: Sampling temperature. ``None`` or a value <= 0 selects
            greedy decoding instead of sampling.
        model_manager: Object exposing ``model``, ``tokenizer``, ``device``
            and ``model_loaded``.

    Returns:
        The generated text (for CoT requests, the extracted JSON array when
        one is found), ``"Model not loaded"`` when no usable model manager is
        given, or ``"Error: <message>"`` if generation raises.
    """
    if not model_manager or not model_manager.model_loaded:
        return "Model not loaded"

    try:
        tokenizer = model_manager.tokenizer
        model = model_manager.model

        # Detect request type
        lowered_prompt = prompt.lower()
        is_cot_request = any(phrase in lowered_prompt for phrase in _COT_TRIGGER_PHRASES)

        # Context window reported by the model config (8192 if not present).
        max_context = getattr(model.config, "max_position_embeddings", 8192)
        logger.info(f"Model context: {max_context} tokens")

        if is_cot_request:
            system_msg = "You are an expert at generating JSON training data. Return only valid JSON arrays as requested, no additional text."
            # CoT needs substantial output for complete JSON.
            max_new_tokens = 3000
            min_new_tokens = 500
        else:
            system_msg = "You are a helpful AI assistant generating high-quality training data."
            max_new_tokens = 1500
            min_new_tokens = 50

        formatted_prompt = (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"{system_msg}\n"
            "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
            f"{prompt}\n"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )

        # Reserve room for the output plus a 100-token safety margin.
        max_input_tokens = max_context - max_new_tokens - 100
        logger.info(f"Token plan: Input≤{max_input_tokens}, Output={min_new_tokens}-{max_new_tokens}")

        # The template above already starts with <|begin_of_text|>, so the
        # tokenizer must not prepend a second BOS token.
        # NOTE(review): truncation presumably trims from the end of the text,
        # which would drop the assistant header on over-long prompts - confirm.
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens,
            add_special_tokens=False,
        )

        # Move tensors onto the device that holds the model weights.
        if model_manager.device == "cuda:0":
            target_device = next(model.parameters()).device
            inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
            "early_stopping": False,
            "repetition_penalty": 1.1,
        }
        if temperature is not None and temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        else:
            # Sampling rejects a non-positive temperature; decode greedily.
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # Decode only the newly generated tokens. Decoding the whole sequence
        # with skip_special_tokens=True removes the header markers, so the
        # prompt can no longer be located or sliced off by character count.
        input_len = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][input_len:]
        logger.info(f"Generated {generated_ids.shape[0]} tokens (min was {min_new_tokens})")
        response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # For CoT, return just the JSON array when one can be isolated.
        if is_cot_request:
            extracted = _extract_json_array(response)
            if extracted is not None:
                logger.info(f"Extracted JSON: {len(extracted)} chars")
                response = extracted

        logger.info(f"Final response: {len(response)} chars")
        return response.strip()
    except Exception as e:
        # Top-level boundary for the UI/API: log with traceback and report
        # the failure as text instead of raising.
        logger.exception(f"Generation error: {e}")
        return f"Error: {e}"
| # Create the shared ModelManager used by all handlers below. Its constructor calls | |
| # load_model(), so the model load is attempted as soon as this statement runs. | |
| model_manager = ModelManager() | |
def respond(message, history, temperature):
    """Gradio handler that returns the generated text for *message*.

    ``history`` is accepted for signature compatibility but is not used.
    If generation raises, the error is logged and returned as
    ``"Error: <message>"``.
    """
    try:
        # The simple interface only needs the plain response string.
        return generate_response(message, temperature, model_manager)
    except Exception as exc:
        logger.error(f"Error in respond: {exc}")
        return f"Error: {exc}"
# API function for external calls
def api_respond(message, history=None, temperature=0.8, json_mode=None, template=None):
    """Return the generation result in the chat-message payload clients expect.

    The result is ``[[user_turn, assistant_turn], ""]`` where each turn is a
    dict with ``role``, ``metadata``, ``content`` and ``options`` keys. When
    generation raises, the assistant content is ``"Error: <message>"``.
    ``history``, ``json_mode`` and ``template`` are accepted but not used.
    """
    try:
        reply = generate_response(message, temperature, model_manager)
    except Exception as exc:
        logger.error(f"API Error: {exc}")
        reply = f"Error: {exc}"

    # Same envelope on success and failure; only the assistant content differs.
    user_turn = {"role": "user", "metadata": None, "content": message, "options": None}
    assistant_turn = {"role": "assistant", "metadata": None, "content": reply, "options": None}
    return [[user_turn, assistant_turn], ""]
| # Build the Gradio Blocks app (`demo`): a prompt box, temperature slider and | |
| # Generate button wired to ui_respond, plus an "API" tab wired to api_respond. | |
| with gr.Blocks(title="Question Generation API") as demo: | |
| gr.Markdown("# Question Generation API - Elegant Architecture") | |
| with gr.Row(): | |
| with gr.Column(): | |
| message_input = gr.Textbox(label="Message", placeholder="Enter your prompt...", lines=5) | |
| temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature") | |
| submit_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(): | |
| response_output = gr.Textbox(label="Response", lines=15, max_lines=30) | |
| # Handler for the Generate button: forwards the message and temperature to | |
| # generate_response with the module-level model_manager and returns its text. | |
| def ui_respond(message, temperature): | |
| return generate_response(message, temperature, model_manager) | |
| submit_btn.click(ui_respond, inputs=[message_input, temperature_input], outputs=[response_output]) | |
| # API tab: calls api_respond and shows its return value as JSON. | |
| with gr.Tab("API"): | |
| with gr.Row(): | |
| api_message = gr.Textbox(label="Message", lines=3) | |
| api_temp = gr.Number(value=0.8, label="Temperature") | |
| api_submit = gr.Button("Call API") | |
| api_output = gr.JSON(label="API Response") | |
| # Inputs map positionally to api_respond(message, history, temperature); | |
| # a gr.State holding an empty list fills the history slot. | |
| api_submit.click(api_respond, inputs=[api_message, gr.State([]), api_temp], outputs=[api_output]) | |
| # Start the server only when this file is run as a script: host 0.0.0.0, | |
| # port 7860, share disabled. | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=False) | |