Spaces:
Sleeping
Sleeping
File size: 12,961 Bytes
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger, stdlib convention
class ModelManager:
    """Owns the Llama-3.1-8B-Instruct model/tokenizer and device placement.

    Attributes:
        model: the loaded causal LM, or None until load_model() succeeds.
        tokenizer: the matching tokenizer, or None.
        device: "cuda:0" or "cpu", chosen at load time.
        model_loaded: True only after a successful load; callers check this
            before attempting generation.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        # Load eagerly so the long download/initialisation happens at
        # construction time rather than on the first request.
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer, setting model_loaded on success.

        Never raises: any failure is logged and leaves model_loaded False.
        """
        try:
            logger.info("Starting model loading...")
            # Prefer the first CUDA GPU when available.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"
            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            # Gated repo: an HF access token must be supplied via the env.
            hf_token = os.getenv("HF_TOKEN")
            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                # fp16 on GPU to halve VRAM; full precision on CPU.
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                # device_map places the weights directly on GPU 0, so the
                # explicit .to(self.device) the original did afterwards was
                # redundant and has been removed.
                device_map={"": 0} if self.device == "cuda:0" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                token=hf_token
            )
            self.model_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            # Boundary handler: keep the app alive and report "not loaded".
            # logger.exception records the full traceback for debugging.
            logger.exception(f"Error loading model: {str(e)}")
            self.model_loaded = False
# Initialize model manager at import time so the model starts loading
# before the first request; generate_response() checks model_loaded.
model_manager = ModelManager()
def _extract_json_array(text):
    """Trim *text* to the first complete top-level JSON array.

    Scans character-by-character, tracking string/escape state so brackets
    inside JSON string literals are ignored. Returns:
      - the first balanced ``[ ... ]`` span when one closes,
      - everything from the opening ``[`` onward when the array never
        closes (truncated generation - the client may still recover),
      - *text* unchanged when no top-level array is found.
    """
    bracket_depth = 0  # depth of [ ]
    brace_depth = 0    # depth of { }
    in_string = False
    escape_next = False
    start_idx = None
    end_idx = None
    for i, ch in enumerate(text):
        # A backslash escapes the following character (only meaningful
        # inside strings, but harmless to honour everywhere).
        if escape_next:
            escape_next = False
            continue
        if ch == '\\':
            escape_next = True
            continue
        # A quote toggles string state; escaped quotes were consumed above,
        # so no extra escape check is needed here (the original's
        # `and not escape_next` was always true at this point).
        if ch == '"':
            in_string = not in_string
            continue
        # Only count structure outside of string literals.
        if not in_string:
            if ch == '[':
                if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
                    start_idx = i
                bracket_depth += 1
            elif ch == ']':
                bracket_depth = max(0, bracket_depth - 1)
                if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
                    end_idx = i
                    break
            elif ch == '{':
                brace_depth += 1
            elif ch == '}':
                brace_depth = max(0, brace_depth - 1)
    if start_idx is not None and end_idx is not None and end_idx > start_idx:
        json_text = text[start_idx:end_idx + 1]
        logger.info(f"Extracted complete JSON array of length {len(json_text)}")
        return json_text
    if start_idx is not None:
        # Found a start but no end - the response was truncated; hand the
        # partial array to the client and let it attempt recovery.
        logger.warning("JSON array started but never closed - response truncated")
        return text[start_idx:]
    return text


def _strip_prompt_echo(generated_text, formatted_prompt):
    """Best-effort removal of the echoed prompt; returns the reply text.

    NOTE(review): decode() is called with skip_special_tokens=True, so the
    '<|start_header_id|>...' markers are usually stripped before this runs
    and the heuristic fallbacks below do most of the work - confirm against
    real outputs.
    """
    import re
    marker = "<|start_header_id|>assistant<|end_header_id|>"
    if marker in generated_text:
        return generated_text.split(marker)[-1].strip()
    # Fallback 1: jump to the first JSON bracket if it appears well past
    # the prompt region (avoids grabbing brackets inside the prompt).
    json_match = re.search(r'(\[|\{)', generated_text)
    if json_match and json_match.start() > len(formatted_prompt) // 2:
        return generated_text[json_match.start():].strip()
    # Fallback 2: split on likely prompt-terminator patterns and keep the
    # last substantial part.
    response = generated_text
    for pattern in ("<|end_header_id|>", "<|eot_id|>", "assistant", "\n\n"):
        if pattern in generated_text:
            parts = generated_text.split(pattern)
            if len(parts) > 1:
                candidate = parts[-1].strip()
                if len(candidate) > 20:  # ignore trivially short fragments
                    response = candidate
                    break
    # Fallback 3: conservatively skip a prefix roughly the prompt's length.
    if response == generated_text:
        skip_chars = min(len(formatted_prompt) // 2, len(generated_text) // 3)
        response = generated_text[skip_chars:].strip()
    return response


def generate_response(prompt, temperature=0.8):
    """Run *prompt* through the loaded model and return the reply text.

    Args:
        prompt: raw user prompt; wrapped in the Llama-3.1 chat template here.
        temperature: sampling temperature, passed straight to generate().

    Returns:
        The heuristically cleaned generated text, or a human-readable error
        string - this function never raises.
    """
    if not model_manager.model_loaded:
        return "Model not loaded yet. Please wait..."
    try:
        # Llama-3.1 chat format: single user turn, then the assistant header
        # so the model continues as the assistant.
        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
        # Context window - Llama 3.1 supports up to 131k tokens.
        try:
            max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072)
        except Exception:
            max_ctx = 131072
        logger.info(f"Model max context: {max_ctx} tokens")
        # Chain-of-thinking requests get a much larger generation budget.
        is_cot_request = ("chain-of-thinking" in prompt.lower() or
                          "chain of thinking" in prompt.lower() or
                          "Return exactly this JSON array" in prompt or
                          ("verbatim" in prompt.lower() and "json array" in prompt.lower()))
        if is_cot_request:
            gen_max_new_tokens = 16384  # very high limit for complete responses
            min_tokens = 2000           # force long, complete generations
            # Remaining context (minus a small safety buffer) goes to input.
            allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
            logger.info(f"CoT REQUEST - MAXIMIZED: min_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}, input_limit={allowed_input_tokens}")
        else:
            gen_max_new_tokens = 8192
            min_tokens = 200
            allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
        # Tokenize with truncation so input + generation fits the context.
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=allowed_input_tokens
        )
        # Move the input tensors onto the model's actual device.
        if model_manager.device == "cuda:0":
            model_device = next(model_manager.model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=gen_max_new_tokens,
                min_new_tokens=min_tokens,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                num_beams=1,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                eos_token_id=model_manager.tokenizer.eos_token_id,
                # early_stopping removed: it only applies to beam search and
                # merely triggered a transformers warning with num_beams=1.
                repetition_penalty=1.05,
                no_repeat_ngram_size=0,
                length_penalty=1.0,
                use_cache=True
            )
        generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Log generation stats; warn if min_new_tokens was not reached.
        input_length = inputs['input_ids'].shape[1]
        output_length = outputs[0].shape[0]
        generated_length = output_length - input_length
        logger.info(f"Generation stats - Input: {input_length} tokens, Generated: {generated_length} tokens, Min required: {min_tokens}")
        if generated_length < min_tokens:
            logger.warning(f"Generated {generated_length} tokens but minimum was {min_tokens} - response may be truncated")
        # Trim to the first complete JSON array to drop trailing prose
        # such as 'assistant' or 'Message'.
        try:
            generated_text = _extract_json_array(generated_text)
        except Exception as e:
            # Best effort only - fall back to the untrimmed text.
            logger.warning(f"Error in JSON extraction: {e}")
        response = _strip_prompt_echo(generated_text, formatted_prompt)
        logger.info(f"Generated response length: {len(response)} characters")
        return response
    except Exception as e:
        # Top-level boundary: surface the error as a string for the UI.
        logger.error(f"Error generating response: {str(e)}")
        return f"Error: {str(e)}"
def respond(message, history, temperature):
    """Gradio chat callback: generate a reply and append the exchange.

    Returns the updated history plus an empty string to clear the textbox.
    """
    reply = generate_response(message, temperature)
    # Record both turns in the "messages" format expected by gr.Chatbot.
    history.extend([
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ])
    return history, ""
# Create the Gradio interface.
# Layout: a wide chat column (chatbot + message box + buttons) beside a
# narrow settings column (temperature slider + usage notes).
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Simple LLM API")
    gr.Markdown("Send a prompt and get a response. No templates, just direct model interaction.")
    with gr.Row():
        with gr.Column(scale=4):
            # type="messages" expects [{"role": ..., "content": ...}] dicts,
            # which is exactly what respond() appends.
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=400
            )
            msg = gr.Textbox(
                label="Message",
                placeholder="Enter your prompt here...",
                lines=3
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")
        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative"
            )
            gr.Markdown("""
### API Usage
This model accepts any prompt and returns a response.
For JSON responses, include instructions in your prompt like:
- "Return as a JSON array"
- "Format as JSON"
- "List as JSON"
The model will follow your instructions.
""")
    # Set up event handlers.
    # Both the Send button and Enter in the textbox trigger generation; the
    # second output clears the textbox. Clear resets both chat and textbox.
    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
if __name__ == "__main__":
    # Bind to all interfaces on the standard Gradio port inside the
    # container; no public share link.
    launch_config = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_config)