# question-generation-api / gradio_app.py
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
def __init__(self):
self.model = None
self.tokenizer = None
self.device = None
self.model_loaded = False
self.load_model()
def load_model(self):
"""Load the model and tokenizer"""
try:
logger.info("Starting model loading...")
# Check if CUDA is available
if torch.cuda.is_available():
torch.cuda.set_device(0)
self.device = "cuda:0"
else:
self.device = "cpu"
logger.info(f"Using device: {self.device}")
if self.device == "cuda:0":
logger.info(f"GPU: {torch.cuda.get_device_name()}")
logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
# Get HF token from environment
hf_token = os.getenv("HF_TOKEN")
logger.info("Loading Llama-3.1-8B-Instruct model...")
base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
self.tokenizer = AutoTokenizer.from_pretrained(
base_model_name,
use_fast=True,
trust_remote_code=True,
token=hf_token
)
self.model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
device_map="auto" if self.device == "cuda:0" else None,
trust_remote_code=True,
token=hf_token,
                attn_implementation="eager"  # eager attention avoids a flash-attn dependency
)
# Set pad token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model_loaded = True
logger.info("βœ… Model loaded successfully!")
except Exception as e:
logger.error(f"❌ Error loading model: {str(e)}")
self.model_loaded = False
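# Memory note: float16 weights for an 8B-parameter model need roughly 16 GB of
# VRAM before activations. On smaller GPUs, 4-bit quantization is a common
# fallback (a sketch, not used here; assumes the bitsandbytes package is
# installed alongside transformers):
#
#     from transformers import BitsAndBytesConfig
#     quant_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.float16,
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         base_model_name,
#         quantization_config=quant_config,
#         token=hf_token,
#     )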
def generate_response(prompt, temperature=0.8):
"""ZERO TRUNCATION GENERATION - Never cut anything!"""
global model_manager
if not model_manager or not model_manager.model_loaded:
return "Model not loaded"
try:
# Detect CoT requests
is_cot = any(phrase in prompt.lower() for phrase in [
"return exactly this json array",
"chain of thinking",
"verbatim"
])
logger.info(f"🎯 Request type: {'CoT' if is_cot else 'Standard'}")
# Simple system message
if is_cot:
system = "You are an expert at generating JSON training data exactly as requested."
else:
system = "You are a helpful AI assistant."
# Format prompt
formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
# Optimized token limits for speed
if is_cot:
max_new = 1500 # Reduced for speed
min_new = 400 # Reduced minimum
else:
max_new = 800 # Significantly reduced for speed
min_new = 50 # Lower minimum
max_input = 6000 # Safe input limit
logger.info(f"πŸ”’ Token allocation: Input≀{max_input}, Output={min_new}-{max_new}")
        # Tokenize; add_special_tokens=False because the template above already
        # starts with <|begin_of_text|>, and the tokenizer would otherwise
        # prepend a second BOS token
        inputs = model_manager.tokenizer(
            formatted,
            return_tensors="pt",
            truncation=True,
            max_length=max_input,
            add_special_tokens=False
        )
# Move to device
if model_manager.device == "cuda:0":
inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}
logger.info("πŸš€ Starting generation...")
# Generate with generous parameters
with torch.no_grad():
outputs = model_manager.model.generate(
**inputs,
max_new_tokens=max_new,
min_new_tokens=min_new,
temperature=temperature,
top_p=0.9,
do_sample=True,
                num_beams=1,  # single beam; with do_sample=True this samples rather than doing greedy search
                pad_token_id=model_manager.tokenizer.eos_token_id,
                # early_stopping only affects beam search, so it is omitted here
                repetition_penalty=1.1,
use_cache=True
)
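        # Design note: for more deterministic JSON output, do_sample=False (which
        # ignores temperature) is the usual alternative; sampling is kept here so
        # the generated training examples vary between calls.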
        # Decode ONLY the newly generated tokens. Decoding the full sequence with
        # skip_special_tokens=True strips the <|start_header_id|> markers, so
        # marker-based extraction of the assistant turn could never match;
        # slicing off the prompt tokens avoids the problem entirely.
        input_length = inputs["input_ids"].shape[1]
        response = model_manager.tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=True
        )
        logger.info(f"📝 Response length: {len(response)} chars")
        logger.info(f"📝 Response preview: {response[:200]}...")
        # For CoT, if the response contains a JSON array, extract and validate it
        if is_cot and '[' in response and ']' in response:
            # Find the outermost JSON array
            first_bracket = response.find('[')
            last_bracket = response.rfind(']')
            if last_bracket > first_bracket:
                json_candidate = response[first_bracket:last_bracket + 1]
                try:
                    parsed = json.loads(json_candidate)
                    # Expect a list with at least 2 user/assistant pairs
                    user_count = sum(
                        1 for item in parsed
                        if isinstance(item, dict) and "user" in item and "assistant" in item
                    )
                    if user_count >= 2:
                        response = json_candidate
                        logger.info(f"🎯 Extracted JSON array with {user_count} items: {len(response)} chars")
                    else:
                        logger.info(f"⚠️ JSON array has only {user_count} items, using full response")
                except (json.JSONDecodeError, TypeError):
                    logger.info("⚠️ JSON candidate failed to parse, using full response")
# Final response
response = response.strip()
logger.info(f"βœ… FINAL response: {len(response)} chars")
logger.info(f"🎬 Starts with: {response[:150]}...")
logger.info(f"🎭 Ends with: ...{response[-150:]}")
return response
except Exception as e:
logger.error(f"πŸ’₯ Generation error: {e}")
return f"Error: {e}"
# Initialize model ONCE
model_manager = ModelManager()
def api_respond(message, history_str, temperature, json_mode, template):
    """ZERO TRUNCATION API - pure content, no wrappers.

    history_str, json_mode, and template are unused placeholders that keep the
    Gradio input signature (and therefore the public API) stable.
    """
try:
logger.info(f"πŸ“¨ API Request: {len(message)} chars, temp={temperature}")
response = generate_response(message, temperature)
logger.info(f"πŸ“€ API Response: {len(response)} chars")
return response
except Exception as e:
logger.error(f"πŸ’₯ API Error: {e}")
return f"Error: {e}"
# BULLETPROOF GRADIO INTERFACE
demo = gr.Interface(
fn=api_respond,
inputs=[
gr.Textbox(label="Message", lines=8, placeholder="Enter your prompt here..."),
gr.Textbox(label="History", value="[]", visible=False),
gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
gr.Textbox(label="JSON Mode", value="", visible=False),
gr.Textbox(label="Template", value="", visible=False)
],
outputs=gr.Textbox(label="Response", lines=20, max_lines=50),
title="🎯 Question Generation API - ZERO TRUNCATION",
description="Rebuilt from scratch with ZERO text cutting. Generates complete responses every time.",
api_name="respond"
)
if __name__ == "__main__":
# Enable queue with concurrency limit of 10
demo.queue(
default_concurrency_limit=10, # Handle 10 concurrent requests
max_size=100 # Allow up to 100 requests in queue
).launch(server_name="0.0.0.0", server_port=7860, share=False)
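
# Example client usage (a sketch, assuming the app is reachable on port 7860
# and the gradio_client package is installed; "/respond" matches the api_name
# declared on the Interface above):
#
#     from gradio_client import Client
#
#     client = Client("http://localhost:7860/")
#     result = client.predict(
#         "Generate three quiz questions about photosynthesis.",  # message
#         "[]",   # history placeholder (unused)
#         0.8,    # temperature
#         "",     # JSON mode placeholder (unused)
#         "",     # template placeholder (unused)
#         api_name="/respond",
#     )
#     print(result)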