# question-generation-api / gradio_app_new.py
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import re
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            logger.info("Starting model loading...")
            # Check if CUDA is available
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"
            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Get HF token from environment
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map="auto" if self.device == "cuda:0" else None,
                trust_remote_code=True,
                token=hf_token
            )
            # Set pad token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")
        except Exception as e:
            logger.error(f"❌ Error loading model: {str(e)}")
            self.model_loaded = False
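
# Note: on GPUs with limited VRAM, load_model() above could optionally load the
# 8B model in 4-bit instead of float16. A minimal sketch, assuming the
# `bitsandbytes` package is available (it is not part of the current setup):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_4bit=True,
#                                     bnb_4bit_compute_dtype=torch.float16)
#   model = AutoModelForCausalLM.from_pretrained(base_model_name,
#                                                quantization_config=quant_config,
#                                                device_map="auto", token=hf_token)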


def generate_response(prompt, temperature=0.8, model_manager=None):
    """Generate a completion for the given prompt, with special handling for
    chain-of-thought / JSON-array generation requests."""
    if not model_manager or not model_manager.model_loaded:
        return "Model not loaded"
    try:
        # Detect request type
        is_cot_request = any(phrase in prompt.lower() for phrase in [
            "return exactly this json array",
            "chain of thinking",
            "verbatim",
            "json array (no other text)"
        ])

        # Get the model's actual context window
        max_context = getattr(model_manager.model.config, "max_position_embeddings", 8192)
        logger.info(f"Model context: {max_context} tokens")

        # Prompt formatting (Llama 3.1 chat format)
        if is_cot_request:
            system_msg = "You are an expert at generating JSON training data. Return only valid JSON arrays as requested, no additional text."
        else:
            system_msg = "You are a helpful AI assistant generating high-quality training data."
        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_msg}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
        # Token allocation
        if is_cot_request:
            # CoT needs substantial output for complete JSON
            max_new_tokens = 3000  # Generous but not excessive
            min_new_tokens = 500   # Ensure JSON completion
        else:
            max_new_tokens = 1500
            min_new_tokens = 50
        # Reserve space for the input
        max_input_tokens = max_context - max_new_tokens - 100
        logger.info(f"Token plan: Input≤{max_input_tokens}, Output={min_new_tokens}-{max_new_tokens}")

        # Tokenize
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        )
        # Move inputs to the same device as the model
        if model_manager.device == "cuda:0":
            inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}

        # Generation
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        # Log stats
        input_len = inputs['input_ids'].shape[1]
        output_len = outputs[0].shape[0]
        generated_len = output_len - input_len
        logger.info(f"Generated {generated_len} tokens (min was {min_new_tokens})")

        # Decode only the newly generated tokens. Decoding the full sequence with
        # skip_special_tokens=True strips the <|start_header_id|> markers, so
        # splitting the full decoded text on them is unreliable.
        response = model_manager.tokenizer.decode(
            outputs[0][input_len:], skip_special_tokens=True
        ).strip()
        # For CoT, extract clean JSON if possible
        if is_cot_request and '[' in response and ']' in response:
            # Find the most complete JSON array (the pattern tolerates one level
            # of nested brackets, enough for an array of message objects)
            json_pattern = r'\[(?:[^[\]]+|\[[^\]]*\])*\]'
            matches = re.findall(json_pattern, response, re.DOTALL)
            if matches:
                # Pick the longest match (most complete)
                best_match = max(matches, key=len)
                # Verify it has reasonable content
                if '"user"' in best_match and '"assistant"' in best_match:
                    logger.info(f"Extracted JSON: {len(best_match)} chars")
                    response = best_match

        logger.info(f"Final response: {len(response)} chars")
        return response.strip()
    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Error: {e}"


# Initialize model
model_manager = ModelManager()


def respond(message, history, temperature):
    """Gradio interface function"""
    try:
        response = generate_response(message, temperature, model_manager)
        history.append([message, response])
        return history, ""
    except Exception as e:
        logger.error(f"Error in respond: {e}")
        history.append([message, f"Error: {e}"])
        return history, ""


# Create Gradio interface
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Question Generation API")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Message", placeholder="Enter your prompt...")
    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")
    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")

    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
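
# Example client call (a minimal sketch; the endpoint name below is an
# assumption, since no api_name is set above. Check `Client(...).view_api()`
# for the route this interface actually exposes):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   history, _ = client.predict(
#       "Generate 5 quiz questions about photosynthesis",  # message
#       [],                                                # history
#       0.8,                                               # temperature
#       api_name="/respond",
#   )
#   print(history[-1][1])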