# question-generation-api / gradio_app_simple.py
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import re
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer"""
        try:
            logger.info("Starting model loading...")

            # Check if CUDA is available
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"

            logger.info(f"Using device: {self.device}")
            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Get HF token from environment
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map="auto" if self.device == "cuda:0" else None,
                trust_remote_code=True,
                token=hf_token
            )

            # Set pad token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")

        except Exception as e:
            logger.error(f"❌ Error loading model: {str(e)}")
            self.model_loaded = False
def generate_response(prompt, temperature=0.8, model_manager=None):
    """SIMPLE, WORKING GENERATION"""
    if not model_manager or not model_manager.model_loaded:
        return "Model not loaded"

    try:
        # Detect request type
        is_cot_request = any(phrase in prompt.lower() for phrase in [
            "return exactly this json array",
            "chain of thinking",
            "verbatim"
        ])

        # Get model context
        max_context = getattr(model_manager.model.config, "max_position_embeddings", 8192)
        logger.info(f"Model context: {max_context} tokens")

        # SIMPLE PROMPT
        if is_cot_request:
            system_msg = "Generate JSON training data exactly as requested."
        else:
            system_msg = "You are a helpful AI assistant."

        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_msg}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

        # REASONABLE TOKEN LIMITS
        if is_cot_request:
            max_new_tokens = 2048  # Reasonable for JSON
            min_new_tokens = 300   # Ensure completion
        else:
            max_new_tokens = 1024
            min_new_tokens = 50

        max_input_tokens = max_context - max_new_tokens - 100
        logger.info(f"Tokens: Input≤{max_input_tokens}, Output={min_new_tokens}-{max_new_tokens}")

        # Tokenize
        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens
        )

        # Move to device
        if model_manager.device == "cuda:0":
            inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}

        # SIMPLE GENERATION
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                early_stopping=False,
                repetition_penalty=1.1
            )
        # Decode only the newly generated tokens so the prompt text is excluded.
        # (With skip_special_tokens=True the <|start_header_id|> markers are
        # stripped from the decode, so splitting on them is unreliable.)
        prompt_length = inputs["input_ids"].shape[1]
        response = model_manager.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        ).strip()
        # For CoT, try to extract JSON
        if is_cot_request and '[' in response and ']' in response:
            json_match = re.search(r'\[.*\]', response, re.DOTALL)
            if json_match:
                candidate = json_match.group(0)
                if '"user"' in candidate and '"assistant"' in candidate:
                    response = candidate

        logger.info(f"Response: {len(response)} chars")
        return response.strip()

    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Error: {e}"
# Initialize model
model_manager = ModelManager()
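# Illustrative direct call (a sketch, not part of the API surface): a prompt
# containing one of the detection phrases above (e.g. "verbatim" or
# "return exactly this JSON array") routes through the CoT/JSON branch of
# generate_response. The prompt text here is hypothetical.
#
#   example = generate_response(
#       'Return exactly this JSON array of {"user": ..., "assistant": ...} pairs, '
#       'quoting the passage verbatim.',
#       temperature=0.7,
#       model_manager=model_manager,
#   )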
def respond(message, history, temperature, json_mode=None, template=None):
    """Main API function matching original interface"""
    try:
        response = generate_response(message, temperature, model_manager)

        # Return in original format
        return [[
            {"role": "user", "metadata": None, "content": message, "options": None},
            {"role": "assistant", "metadata": None, "content": response, "options": None}
        ], ""]

    except Exception as e:
        logger.error(f"API Error: {e}")
        return [[
            {"role": "user", "metadata": None, "content": message, "options": None},
            {"role": "assistant", "metadata": None, "content": f"Error: {e}", "options": None}
        ], ""]
# Create simple interface
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Message", lines=5),
        gr.State(value=[]),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Textbox(label="JSON Mode", value="", visible=False),
        gr.Textbox(label="Template", value="", visible=False)
    ],
    outputs=[
        gr.JSON(label="Response"),
        gr.Textbox(label="Status", visible=False)
    ],
    title="Question Generation API - Simple & Working",
    api_name="respond"
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
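
# Example client call (a minimal sketch, assuming the app is running locally on
# port 7860 and that the gradio_client package is installed; the message text
# and URL are placeholders). The five positional arguments mirror the Interface
# inputs above: message, history, temperature, json_mode, template.
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   chat, status = client.predict(
#       "Write three quiz questions about photosynthesis.",  # message
#       [],                                                   # history (unused State)
#       0.8,                                                  # temperature
#       "",                                                   # json_mode
#       "",                                                   # template
#       api_name="/respond",
#   )
#   print(chat[-1]["content"])  # assistant reply from the returned chat pair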