File size: 12,961 Bytes
0cdc4eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import logging
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)

class ModelManager:
    """Loads and holds the Llama-3.1-8B-Instruct model and tokenizer.

    Instantiated once at module import time; ``model_loaded`` tells callers
    whether the (slow, fallible) load succeeded.
    """

    def __init__(self):
        # Populated by load_model(); model_loaded stays False on failure so
        # the serving code can degrade gracefully instead of crashing.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load tokenizer + model onto GPU (fp16) when available, else CPU (fp32).

        Any exception is logged with its traceback and swallowed, leaving
        ``model_loaded`` False.
        """
        try:
            logger.info("Starting model loading...")

            # Prefer the first CUDA device when present.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"
            logger.info("Using device: %s", self.device)

            if self.device == "cuda:0":
                logger.info("GPU: %s", torch.cuda.get_device_name())
                logger.info(
                    "VRAM Available: %.2f GB",
                    torch.cuda.get_device_properties(0).total_memory / 1024**3,
                )

            # Gated repository: an HF access token may be required.
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )
            # Llama tokenizers ship without a pad token; generation code pads
            # with eos, so mirror that here for consistency.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map={"": 0} if self.device == "cuda:0" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                token=hf_token
            )
            # NOTE: device_map={"": 0} already places every weight on cuda:0,
            # so the previous explicit .to(self.device) was redundant and has
            # been removed.

            # Inference-only server: switch off dropout / training-mode layers.
            self.model.eval()

            self.model_loaded = True
            logger.info("Model loaded successfully!")

        except Exception:
            # Keep the full traceback in the logs; leave model_loaded False.
            logger.exception("Error loading model")
            self.model_loaded = False

# Initialize model manager
# Instantiated at import time so the (slow) model load happens once, before
# the Gradio app starts serving requests.
model_manager = ModelManager()

def _token_budgets(max_ctx, is_cot):
    """Return (max_new_tokens, min_new_tokens, allowed_input_tokens).

    CoT requests get a much larger generation budget. Budgets are clamped so
    the input limit never goes negative when the model's context window is
    smaller than the generation budget (the old code could hand the tokenizer
    a negative ``max_length``).
    """
    if is_cot:
        gen_max_new_tokens = 16384  # very high limit for complete CoT responses
        min_tokens = 2000           # high minimum to force complete generation
    else:
        gen_max_new_tokens = 8192
        min_tokens = 200
    allowed_input_tokens = max_ctx - gen_max_new_tokens - 100  # small safety buffer
    if allowed_input_tokens < 1:
        # Context too small for the requested budget: reserve a quarter of the
        # window for input and shrink the generation budget to what remains.
        allowed_input_tokens = max(1, max_ctx // 4)
        gen_max_new_tokens = max(1, max_ctx - allowed_input_tokens - 100)
        min_tokens = min(min_tokens, gen_max_new_tokens)
    return gen_max_new_tokens, min_tokens, allowed_input_tokens


def _extract_first_json_array(text):
    """Locate the first complete top-level JSON array in *text*.

    Returns (start, end) character indices of ``[`` and ``]``; ``end`` is None
    when an array opens but never closes (truncated output), and both are None
    when no top-level array appears. Brackets inside JSON strings (including
    after escape sequences) are ignored. Fixes the old scanner, which treated
    a backslash *outside* a string as an escape and carried a dead
    ``not escape_next`` check on the quote test.
    """
    bracket_depth = 0  # [ ]
    brace_depth = 0    # { }
    in_string = False
    escape_next = False
    start_idx = None
    for i, ch in enumerate(text):
        if in_string:
            # Escapes are only meaningful inside JSON strings.
            if escape_next:
                escape_next = False
            elif ch == '\\':
                escape_next = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == '[':
            if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
                start_idx = i
            bracket_depth += 1
        elif ch == ']':
            bracket_depth = max(0, bracket_depth - 1)
            if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
                return start_idx, i
        elif ch == '{':
            brace_depth += 1
        elif ch == '}':
            brace_depth = max(0, brace_depth - 1)
    return start_idx, None


def _extract_assistant_reply(generated_text, formatted_prompt):
    """Strip the echoed prompt and return just the assistant's reply text."""
    marker = "<|start_header_id|>assistant<|end_header_id|>"
    if marker in generated_text:
        return generated_text.split(marker)[-1].strip()

    # Fallback: look for JSON content starting past the prompt region.
    json_match = re.search(r'(\[|\{)', generated_text)
    if json_match and json_match.start() > len(formatted_prompt) // 2:
        return generated_text[json_match.start():].strip()

    # Otherwise split on known prompt-terminator patterns and keep the last
    # substantial chunk.
    for pattern in ("<|end_header_id|>", "<|eot_id|>", "assistant", "\n\n"):
        if pattern in generated_text:
            candidate = generated_text.split(pattern)[-1].strip()
            if len(candidate) > 20:  # ensure it's not too short
                return candidate

    # Ultimate fallback: skip a conservative estimate of the prompt length.
    skip_chars = min(len(formatted_prompt) // 2, len(generated_text) // 3)
    return generated_text[skip_chars:].strip()


def generate_response(prompt, temperature=0.8):
    """Generate a completion for *prompt* using the shared model.

    Returns the assistant's reply text, or a status/error string when the
    model is unavailable or generation fails. CoT-style prompts are detected
    heuristically and given a larger generation budget.
    """
    if not model_manager.model_loaded:
        return "Model not loaded yet. Please wait..."

    try:
        # Hand-built Llama-3.1 chat template.
        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{prompt}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

        # Model context window (Llama 3.1 supports up to 131k tokens).
        try:
            max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072)
        except Exception:
            max_ctx = 131072

        logger.info("Model max context: %s tokens", max_ctx)

        # Heuristic detection of Chain-of-Thinking requests.
        lowered = prompt.lower()
        is_cot_request = (
            "chain-of-thinking" in lowered
            or "chain of thinking" in lowered
            or "Return exactly this JSON array" in prompt
            or ("verbatim" in lowered and "json array" in lowered)
        )

        gen_max_new_tokens, min_tokens, allowed_input_tokens = _token_budgets(
            max_ctx, is_cot_request
        )
        if is_cot_request:
            logger.info(
                "CoT REQUEST - MAXIMIZED: min_tokens=%s, max_new_tokens=%s, input_limit=%s",
                min_tokens, gen_max_new_tokens, allowed_input_tokens,
            )

        # Tokenize with truncation so input + generation fits the context.
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=allowed_input_tokens
        )

        # Move inputs to the model's actual device (may differ under device_map).
        if model_manager.device == "cuda:0":
            model_device = next(model_manager.model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=gen_max_new_tokens,
                min_new_tokens=min_tokens,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                num_beams=1,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                eos_token_id=model_manager.tokenizer.eos_token_id,
                early_stopping=False,
                repetition_penalty=1.05,
                no_repeat_ngram_size=0,
                length_penalty=1.0,
                use_cache=True
            )

        generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Log generation stats to help diagnose truncated responses.
        input_length = inputs['input_ids'].shape[1]
        generated_length = outputs[0].shape[0] - input_length
        logger.info(
            "Generation stats - Input: %s tokens, Generated: %s tokens, Min required: %s",
            input_length, generated_length, min_tokens,
        )
        if generated_length < min_tokens:
            logger.warning(
                "Generated %s tokens but minimum was %s - response may be truncated",
                generated_length, min_tokens,
            )

        # Post-decode guard: if a top-level JSON array closes, trim to the
        # first full array to drop trailing prose like 'assistant' or 'Message'.
        try:
            start_idx, end_idx = _extract_first_json_array(generated_text)
            if start_idx is not None and end_idx is not None and end_idx > start_idx:
                json_text = generated_text[start_idx:end_idx + 1]
                logger.info("Extracted complete JSON array of length %s", len(json_text))
                generated_text = json_text
            elif start_idx is not None:
                # Found start but no end - let the client handle the fragment.
                logger.warning("JSON array started but never closed - response truncated")
                generated_text = generated_text[start_idx:]
        except Exception as e:
            logger.warning("Error in JSON extraction: %s", e)

        response = _extract_assistant_reply(generated_text, formatted_prompt)

        logger.info("Generated response length: %s characters", len(response))
        return response

    except Exception as e:
        # Full traceback in logs; short message back to the UI.
        logger.exception("Error generating response")
        return f"Error: {str(e)}"

def respond(message, history, temperature):
    """Gradio chat handler.

    Runs the model on *message*, appends the user turn and the assistant
    reply to *history* in place, and returns the updated history plus an
    empty string to clear the input box.
    """
    reply = generate_response(message, temperature)

    # Append both turns in OpenAI-style role/content message format.
    for role, content in (("user", message), ("assistant", reply)):
        history.append({"role": role, "content": content})

    return history, ""

# Create the Gradio interface
# Declarative UI layout: a wide chat column (history + input + buttons) next
# to a narrow settings column (temperature slider + usage notes).
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Simple LLM API")
    gr.Markdown("Send a prompt and get a response. No templates, just direct model interaction.")
    
    with gr.Row():
        with gr.Column(scale=4):
            # Conversation rendered as role/content message dicts
            # (matches the format respond() appends to history).
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=400
            )
            msg = gr.Textbox(
                label="Message",
                placeholder="Enter your prompt here...",
                lines=3
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")
        
        with gr.Column(scale=1):
            # Sampling temperature forwarded straight to generate_response().
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative"
            )
            gr.Markdown("""
            ### API Usage
            This model accepts any prompt and returns a response.
            
            For JSON responses, include instructions in your prompt like:
            - "Return as a JSON array"
            - "Format as JSON"
            - "List as JSON"
            
            The model will follow your instructions.
            """)
    
    # Set up event handlers
    # Both the Send button and pressing Enter in the textbox submit a turn;
    # Clear resets the chat history and empties the input box.
    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from inside a container;
    # 7860 is Gradio's conventional port. share=False disables the public
    # Gradio tunnel.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )