import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            logger.info("Starting model loading...")

            # Check if CUDA is available
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"

            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Get HF token from environment
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map="auto" if self.device == "cuda:0" else None,
                trust_remote_code=True,
                token=hf_token,
                attn_implementation="eager"  # Use eager attention (compatible)
            )

            # Set pad token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model_loaded = False
def generate_response(prompt, temperature=0.8):
    """ZERO TRUNCATION GENERATION - Never cut anything!"""
    global model_manager

    if not model_manager or not model_manager.model_loaded:
        return "Model not loaded"

    try:
        # Detect CoT requests
        is_cot = any(phrase in prompt.lower() for phrase in [
            "return exactly this json array",
            "chain of thinking",
            "verbatim"
        ])

        logger.info(f"Request type: {'CoT' if is_cot else 'Standard'}")

        # Simple system message
        if is_cot:
            system = "You are an expert at generating JSON training data exactly as requested."
        else:
            system = "You are a helpful AI assistant."

        # Format prompt
        formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
        # Optimized token limits for speed
        if is_cot:
            max_new = 1500  # Reduced for speed
            min_new = 400   # Reduced minimum
        else:
            max_new = 800   # Significantly reduced for speed
            min_new = 50    # Lower minimum

        max_input = 6000  # Safe input limit

        logger.info(f"Token allocation: Input≤{max_input}, Output={min_new}-{max_new}")

        # Tokenize
        inputs = model_manager.tokenizer(
            formatted,
            return_tensors="pt",
            truncation=True,
            max_length=max_input
        )

        # Move to device
        if model_manager.device == "cuda:0":
            inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}

        logger.info("Starting generation...")
        # Generate with generous parameters
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=max_new,
                min_new_tokens=min_new,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                num_beams=1,  # Single beam (sampling, not beam search) for speed
                pad_token_id=model_manager.tokenizer.eos_token_id,
                early_stopping=True,  # Only affects beam search; harmless with num_beams=1
                repetition_penalty=1.1,
                use_cache=True
            )
        # Decode the COMPLETE response, keeping special tokens so the assistant
        # header marker below can actually be found (they are stripped before returning)
        full_response = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=False)

        logger.info(f"Full response length: {len(full_response)} chars")
        logger.info(f"Response preview: {full_response[:200]}...")
        # ZERO TRUNCATION EXTRACTION - Find content intelligently but never cut
        response = full_response

        # Look for the assistant response marker
        assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
        if assistant_marker in full_response:
            # Find the position after the marker
            marker_pos = full_response.find(assistant_marker)
            if marker_pos != -1:
                # Start after the marker + some whitespace
                start_pos = marker_pos + len(assistant_marker)

                # Skip any immediate whitespace/newlines
                while start_pos < len(full_response) and full_response[start_pos] in ' \n\r\t':
                    start_pos += 1

                if start_pos < len(full_response):
                    response = full_response[start_pos:]
                    logger.info(f"Extracted after assistant marker: {len(response)} chars")
                else:
                    logger.info("Marker found but no content after, using full response")
            else:
                logger.info("Marker search failed, using full response")
        else:
            logger.info("No assistant marker found, using full response")
        # For CoT, if we have a JSON array, extract it cleanly
        if is_cot and '[' in response and ']' in response:
            # Find the outermost JSON array
            first_bracket = response.find('[')
            last_bracket = response.rfind(']')

            if first_bracket != -1 and last_bracket != -1 and last_bracket > first_bracket:
                json_candidate = response[first_bracket:last_bracket + 1]

                # Validate it contains the expected structure
                if '"user"' in json_candidate and '"assistant"' in json_candidate:
                    # Count the objects to make sure we have multiple items
                    user_count = json_candidate.count('"user"')
                    if user_count >= 2:  # Should have at least 2 user/assistant pairs
                        response = json_candidate
                        logger.info(f"Extracted JSON array with {user_count} items: {len(response)} chars")
                    else:
                        logger.info(f"JSON array has only {user_count} items, using full response")
                else:
                    logger.info("JSON candidate failed validation, using full response")
        # Final response: strip leftover special tokens and surrounding whitespace
        for special_token in ("<|begin_of_text|>", "<|eot_id|>", "<|end_of_text|>"):
            response = response.replace(special_token, "")
        response = response.strip()

        logger.info(f"FINAL response: {len(response)} chars")
        logger.info(f"Starts with: {response[:150]}...")
        logger.info(f"Ends with: ...{response[-150:]}")

        return response

    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Error: {e}"
# Initialize model ONCE
model_manager = ModelManager()


def api_respond(message, history_str, temperature, json_mode, template):
    """ZERO TRUNCATION API - Pure content, no wrappers"""
    try:
        logger.info(f"API Request: {len(message)} chars, temp={temperature}")
        response = generate_response(message, temperature)
        logger.info(f"API Response: {len(response)} chars")
        return response
    except Exception as e:
        logger.error(f"API Error: {e}")
        return f"Error: {e}"
# BULLETPROOF GRADIO INTERFACE
demo = gr.Interface(
    fn=api_respond,
    inputs=[
        gr.Textbox(label="Message", lines=8, placeholder="Enter your prompt here..."),
        gr.Textbox(label="History", value="[]", visible=False),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Textbox(label="JSON Mode", value="", visible=False),
        gr.Textbox(label="Template", value="", visible=False)
    ],
    outputs=gr.Textbox(label="Response", lines=20, max_lines=50),
    title="Question Generation API - ZERO TRUNCATION",
    description="Rebuilt from scratch with ZERO text cutting. Generates complete responses every time.",
    api_name="respond"
)
if __name__ == "__main__":
    # Enable queue with concurrency limit of 10
    demo.queue(
        default_concurrency_limit=10,  # Handle 10 concurrent requests
        max_size=100  # Allow up to 100 requests in queue
    ).launch(server_name="0.0.0.0", server_port=7860, share=False)
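
# --- Example client call (sketch) ---------------------------------------------
# A minimal sketch of how a client could call the "respond" endpoint exposed by
# this Interface. The Space id below is a placeholder, not the real repository
# name; requires `pip install gradio_client`.
#
#   from gradio_client import Client
#
#   client = Client("your-username/your-space-name")  # hypothetical Space id
#   result = client.predict(
#       "Write three quiz questions about photosynthesis.",  # Message
#       "[]",   # History (hidden field, unused)
#       0.8,    # Temperature
#       "",     # JSON Mode (hidden field, unused)
#       "",     # Template (hidden field, unused)
#       api_name="/respond",
#   )
#   print(result)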