|
|
import os |
|
|
# Default the native thread pools of OpenMP, OpenBLAS, MKL and numexpr to a
# single thread. setdefault keeps any value the environment already provides.
# These run before the third-party imports below, presumably so the libraries
# see the values when they initialise.
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_thread_env_var, "1")
del _thread_env_var
|
|
|
|
|
import json |
|
|
import os |
|
|
import gradio as gr |
|
|
from typing import Optional, Dict, Any, Union |
|
|
from PIL import Image |
|
|
from pydantic import BaseModel |
|
|
import logging |
|
|
from config import Config |
|
|
|
|
|
|
|
|
# Optional dependency: llama-cpp-python. When it is missing, the app still
# starts (the UI reports that the model is not loaded). Every imported name is
# bound to None so later references fail predictably instead of raising
# NameError.
try:
    from llama_cpp import Llama, LlamaGrammar, LlamaRAMCache
except ImportError as e:
    print(f"Warning: llama-cpp-python not available: {e}")
    LLAMA_CPP_AVAILABLE = False
    Llama = None
    LlamaGrammar = None
    LlamaRAMCache = None
else:
    LLAMA_CPP_AVAILABLE = True
|
|
|
|
|
|
|
|
# Optional dependency: huggingface_hub is only needed to download the model
# when it is not already on disk. Without it, hf_hub_download is None and
# HUGGINGFACE_HUB_AVAILABLE is False.
try:
    from huggingface_hub import hf_hub_download
except ImportError as import_error:
    print(f"Warning: huggingface_hub not available: {import_error}")
    HUGGINGFACE_HUB_AVAILABLE = False
    hf_hub_download = None
else:
    HUGGINGFACE_HUB_AVAILABLE = True
|
|
|
|
|
|
|
|
# Resolve Config.LOG_LEVEL (a level name such as "info" or "DEBUG") to a numeric
# level. An unknown name falls back to INFO instead of raising at import time,
# which would stop the whole app from starting.
_configured_log_level = str(Config.LOG_LEVEL).strip().upper()
log_level = logging.getLevelName(_configured_log_level)  # int for known names, str otherwise
_log_level_is_valid = isinstance(log_level, int)
if not _log_level_is_valid:
    log_level = logging.INFO

logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)
if not _log_level_is_valid:
    logger.warning("Unknown LOG_LEVEL %r; falling back to INFO", Config.LOG_LEVEL)

# Keep llama-cpp-python's own logger at WARNING and above, whatever the app level is.
llama_logger = logging.getLogger('llama_cpp')
llama_logger.setLevel(logging.WARNING)
|
|
|
|
|
class StructuredOutputRequest(BaseModel):
    """Pydantic model describing a structured-generation request.

    It is not referenced elsewhere in this file; presumably it is kept for an
    API layer (TODO confirm).
    """

    # Text prompt for the model.
    prompt: str
    # Optional image payload as a string; the encoding is not specified here
    # (base64? URL?) -- NOTE(review): confirm with callers.
    image: Union[str, None] = None
    # Parsed JSON schema that the response should follow.
    json_schema: Dict[str, Any]
|
|
|
|
|
class LLMClient:
    """Client for a local GGUF model served through llama-cpp-python.

    Construction does three things:
    1. Resolves the model file, downloading it from Hugging Face when it is
       not present locally.
    2. Loads it with ``Llama``.
    3. Runs a one-token smoke test.

    The generation helpers build Gemma chat-format prompts. They can also
    constrain output with a GBNF grammar derived from a JSON schema.
    """

    # Sampling settings shared by every generation call in this class.
    # Never mutated: always spread into a fresh per-call dict.
    _SAMPLING_DEFAULTS = {"top_k": 64, "top_p": 0.95, "min_p": 0.0}

    # Stop strings for unconstrained generation, so the model stops at its
    # turn boundary. Passed to llama-cpp-python as a list.
    _GEMMA_STOP_TOKENS = ("<end_of_turn>", "<start_of_turn>", "<bos>")

    def __init__(self):
        """Resolve the configured model path and load the model.

        Raises:
            Whatever ``_initialize_model`` raises. That method logs the error
            before re-raising it.
        """
        self.model_path = Config.get_model_path()
        logger.info(f"Using model: {self.model_path}")

        # Set by _initialize_model once the model is loaded.
        self.llm = None

        self._initialize_model()

    def _download_model_if_needed(self) -> str:
        """Return a local path to the model file, downloading it if necessary.

        Lookup order:
        1. ``self.model_path``.
        2. When the ``DOCKER_CONTAINER`` env var is ``"true"``, a few
           well-known container paths.
        3. A ``hf_hub_download`` of ``Config.MODEL_REPO`` /
           ``Config.MODEL_FILENAME`` into ``Config.get_models_dir()``.

        Raises:
            ImportError: a download is needed but huggingface_hub is missing.
            Exception: any error from ``hf_hub_download`` is logged and re-raised.
        """
        if os.path.exists(self.model_path):
            logger.info(f"Model already exists at: {self.model_path}")
            return self.model_path

        if os.getenv('DOCKER_CONTAINER', 'false').lower() == 'true':
            alternative_paths = [
                f"/app/models/{Config.MODEL_FILENAME}",
                f"./models/{Config.MODEL_FILENAME}",
                f"/models/{Config.MODEL_FILENAME}",
                f"/app/{Config.MODEL_FILENAME}",
            ]
            for alt_path in alternative_paths:
                if os.path.exists(alt_path):
                    logger.info(f"Found model at alternative location: {alt_path}")
                    return alt_path

            # Nothing matched. Record what the container's models directory
            # actually contains, to help diagnose volume-mount problems.
            models_dir = "/app/models"
            if os.path.exists(models_dir):
                files = os.listdir(models_dir)
                logger.error(f"Contents of {models_dir}: {files}")
            else:
                logger.error(f"Directory {models_dir} does not exist")

        logger.warning("Model not found in expected locations, attempting download...")

        if not HUGGINGFACE_HUB_AVAILABLE:
            raise ImportError("huggingface_hub is not available. Please install it to download models.")

        logger.info(f"Downloading model {Config.MODEL_REPO}/{Config.MODEL_FILENAME}...")

        models_dir = Config.get_models_dir()
        os.makedirs(models_dir, exist_ok=True)

        try:
            model_path = hf_hub_download(
                repo_id=Config.MODEL_REPO,
                filename=Config.MODEL_FILENAME,
                local_dir=models_dir,
                # An empty token is treated as "no token".
                token=Config.HUGGINGFACE_TOKEN or None,
            )
            logger.info(f"Model downloaded to: {model_path}")
            return model_path
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            raise

    def _initialize_model(self):
        """Locate the GGUF file, validate it, load it and run a smoke test.

        Sets ``self.llm`` on success.

        Raises:
            ImportError: llama-cpp-python is missing, or huggingface_hub is
                missing when a download is required.
            FileNotFoundError: the resolved model path does not exist.
            ValueError: the model file is smaller than 1 MiB (likely truncated).
            Exception: any error from ``Llama(...)`` or the smoke-test call.
            Every error is logged and then re-raised.
        """
        from time import perf_counter

        # Tracks the path actually in use, so error logs report the file that
        # failed rather than the configured default.
        resolved_path = self.model_path
        try:
            if not LLAMA_CPP_AVAILABLE:
                raise ImportError("llama-cpp-python is not available. Please check installation.")

            logger.info("Loading local model...")
            resolved_path = self._download_model_if_needed()

            if not os.path.exists(resolved_path):
                raise FileNotFoundError(f"Model file not found: {resolved_path}")

            # A real GGUF model is far larger than 1 MiB. Anything smaller is
            # most likely an interrupted download.
            file_size = os.path.getsize(resolved_path)
            if file_size < 1024 * 1024:
                raise ValueError(f"Model file seems corrupted or incomplete. Size: {file_size} bytes")
            logger.info(f"Model file verified. Size: {file_size / (1024**3):.2f} GB")

            logger.info("Initializing Llama model...")
            self.llm = Llama(
                model_path=resolved_path,
                n_ctx=Config.N_CTX,
                n_batch=Config.N_BATCH,
                n_gpu_layers=Config.N_GPU_LAYERS,
                use_mlock=Config.USE_MLOCK,
                use_mmap=Config.USE_MMAP,
                vocab_only=False,
                f16_kv=Config.F16_KV,
                logits_all=False,
                embedding=False,
                n_threads=Config.N_THREADS,
                last_n_tokens_size=64,
                lora_base=None,
                lora_path=None,
                seed=Config.SEED,
                verbose=False,
            )
            logger.info("Model successfully loaded and initialized")

            # Smoke test: generating a single token proves the model actually
            # runs, and the timing gives a rough latency indication.
            logger.info("Testing model with simple prompt...")
            start_time = perf_counter()
            test_response = self.llm("Hello", max_tokens=1, temperature=1.0, **self._SAMPLING_DEFAULTS)
            logger.info(f"Model test time: {perf_counter() - start_time:.2f} seconds, response: {test_response}")
            logger.info("Model test successful")

        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            if "Failed to load model from file" in str(e):
                logger.error("This error usually indicates:")
                logger.error("1. Model file is corrupted or incomplete")
                logger.error("2. llama-cpp-python version is incompatible with the model")
                logger.error("3. Insufficient memory to load the model")
                logger.error(f"4. Model path: {resolved_path}")
            # Always propagate. A client without a loaded model must not look healthy.
            raise

    def _validate_json_schema(self, schema: str) -> Dict[str, Any]:
        """Parse a JSON-schema string.

        Raises:
            ValueError: the text is not valid JSON. The decode error is chained
                as the cause.
        """
        try:
            return json.loads(schema)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON schema: {e}") from e

    def _format_prompt_with_schema(self, prompt: str, json_schema: Dict[str, Any]) -> str:
        """Build a single-turn Gemma prompt asking for JSON that follows ``json_schema``.

        The schema is pretty-printed inside a fenced ```json block. The prompt
        ends with ``<start_of_turn>model\\n`` so the model answers next.
        """
        schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
        return (
            f"<bos><start_of_turn>user\n{prompt}\n\n"
            "Please respond in strict accordance with the following JSON schema:\n\n"
            f"```json\n{schema_str}\n```\n\n"
            "Return ONLY valid JSON without additional comments or explanations.<end_of_turn>\n"
            "<start_of_turn>model\n"
        )

    def _format_gemma_chat(self, messages: list) -> str:
        """Render a conversation in Gemma chat format.

        Each turn is rendered as ``<start_of_turn>{role}\\n{content}<end_of_turn>\\n``.
        The prompt opens with ``<bos>`` and ends with ``<start_of_turn>model\\n``
        so the model produces the next turn.

        Args:
            messages: Dicts with ``role`` and ``content`` keys. Missing keys
                default to ``user`` / ``""``. Roles other than ``user`` and
                ``model`` are rendered as ``user``.

        NOTE(review): llama-cpp-python may also add a BOS token while
        tokenizing, so the literal ``<bos>`` could produce a duplicated BOS --
        confirm against the installed version.
        """
        parts = ["<bos>"]
        for message in messages:
            role = message.get('role', 'user')
            if role not in ('user', 'model'):
                role = 'user'
            content = message.get('content', '')
            parts.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")
        parts.append("<start_of_turn>model\n")
        return "".join(parts)

    def generate_chat_response(self, messages: list, max_tokens: Optional[int] = None) -> str:
        """Generate the next model turn for a Gemma-format conversation.

        Args:
            messages: Dicts with ``role`` and ``content`` keys (see ``_format_gemma_chat``).
            max_tokens: Generation limit. Falsy values use ``Config.MAX_NEW_TOKENS``.

        Returns:
            The generated text with surrounding whitespace stripped.

        Raises:
            ValueError: ``messages`` is empty.
        """
        if not messages:
            raise ValueError("Messages list cannot be empty")

        formatted_prompt = self._format_gemma_chat(messages)
        generation_params = {
            "max_tokens": max_tokens or Config.MAX_NEW_TOKENS,
            "temperature": Config.TEMPERATURE,
            **self._SAMPLING_DEFAULTS,
            "echo": False,
            "stop": list(self._GEMMA_STOP_TOKENS),
        }
        response = self.llm(formatted_prompt, **generation_params)
        return response['choices'][0]['text'].strip()

    @staticmethod
    def _parse_json_response(generated_text: str, expect_array: bool = False) -> Any:
        """Recover the JSON value from raw model output.

        Parsing happens in two steps:
        1. The whole text is parsed first. Grammar-constrained output is pure
           JSON, and this is also the only way to recover non-object roots
           such as arrays or strings.
        2. Otherwise, decoding starts at the first ``{`` (or ``[`` when
           ``expect_array``) and stops at the end of that single value. Prose
           or stray braces after the JSON are therefore ignored.

        Returns:
            The decoded value, or a dict with ``error`` and ``raw_response``
            keys when no JSON can be recovered.
        """
        try:
            return json.loads(generated_text)
        except json.JSONDecodeError:
            pass

        opener = "[" if expect_array else "{"
        start = generated_text.find(opener)
        if start == -1:
            return {
                "error": "Could not find JSON in model response",
                "raw_response": generated_text,
            }
        try:
            value, _end = json.JSONDecoder().raw_decode(generated_text, start)
        except json.JSONDecodeError as e:
            return {
                "error": f"JSON parsing error: {e}",
                "raw_response": generated_text,
            }
        return value

    def generate_structured_response(self,
                                     prompt: str,
                                     json_schema: Union[str, Dict[str, Any]],
                                     image: Optional["Image.Image"] = None,
                                     use_grammar: bool = True) -> Union[Dict[str, Any], list]:
        """Generate a response that follows ``json_schema``.

        When ``use_grammar`` is set and a GBNF grammar can be built from the
        schema, decoding is constrained by that grammar. Otherwise the schema
        is embedded in the prompt and generation stops at Gemma turn tokens.

        Args:
            prompt: User request.
            json_schema: Schema as a JSON string or an already-parsed dict.
            image: Accepted for interface compatibility. It is ignored, with a
                warning, because this local model is text-only.
            use_grammar: Try grammar-constrained decoding first.

        Returns:
            The parsed JSON value. This is a dict for object schemas and a
            list for array schemas. On failure it is a dict with an ``error``
            key (and ``raw_response`` when model output exists). Exceptions
            are not propagated.
        """
        try:
            if isinstance(json_schema, str):
                parsed_schema = self._validate_json_schema(json_schema)
            else:
                parsed_schema = json_schema

            formatted_prompt = self._format_prompt_with_schema(prompt, parsed_schema)

            if image is not None:
                logger.warning("Image processing is not supported with this local model")

            logger.info(f"Generating response... (Grammar: {'Enabled' if use_grammar else 'Disabled'})")

            grammar = None
            if use_grammar and LLAMA_CPP_AVAILABLE and LlamaGrammar is not None:
                try:
                    gbnf_grammar = _json_schema_to_gbnf(parsed_schema, "root")
                    grammar = LlamaGrammar.from_string(gbnf_grammar)
                    logger.info("Grammar successfully created from JSON schema")
                except Exception as e:
                    logger.warning(f"Failed to create grammar: {e}. Falling back to non-grammar mode.")

            generation_params = {
                "max_tokens": Config.MAX_NEW_TOKENS,
                "temperature": Config.TEMPERATURE,
                **self._SAMPLING_DEFAULTS,
                "echo": False,
            }

            if grammar is not None:
                generation_params["grammar"] = grammar
                # The grammar enforces the structure, so only the bare request
                # is sent. NOTE(review): schema descriptions are therefore not
                # visible to the model in this mode.
                simple_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
                response = self.llm(simple_prompt, **generation_params)
            else:
                generation_params["stop"] = list(self._GEMMA_STOP_TOKENS)
                response = self.llm(formatted_prompt, **generation_params)

            generated_text = response['choices'][0]['text']
            expect_array = isinstance(parsed_schema, dict) and parsed_schema.get("type") == "array"
            return self._parse_json_response(generated_text, expect_array)

        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            return {
                "error": f"Generation error: {str(e)}"
            }
|
|
|
|
|
def _json_schema_to_gbnf(schema: Dict[str, Any], root_name: str = "root") -> str: |
|
|
"""Convert JSON schema to GBNF (Backus-Naur Form) grammar for structured output""" |
|
|
rules = {} |
|
|
|
|
|
def add_rule(name: str, definition: str): |
|
|
if name not in rules: |
|
|
rules[name] = f"{name} ::= {definition}" |
|
|
|
|
|
def process_type(schema_part: Dict[str, Any], type_name: str = "value") -> str: |
|
|
if "type" not in schema_part: |
|
|
|
|
|
return "string" |
|
|
|
|
|
schema_type = schema_part["type"] |
|
|
|
|
|
if schema_type == "object": |
|
|
|
|
|
properties = schema_part.get("properties", {}) |
|
|
required = schema_part.get("required", []) |
|
|
|
|
|
if not properties: |
|
|
add_rule(type_name, '"{" ws "}"') |
|
|
return type_name |
|
|
|
|
|
|
|
|
property_rules = [] |
|
|
|
|
|
for prop_name, prop_schema in properties.items(): |
|
|
prop_type_name = f"{type_name}_{prop_name}" |
|
|
prop_type = process_type(prop_schema, prop_type_name) |
|
|
property_rules.append(f'"\\"" "{prop_name}" "\\"" ws ":" ws {prop_type}') |
|
|
|
|
|
|
|
|
|
|
|
if len(property_rules) == 1: |
|
|
object_def = f'"{{" ws {property_rules[0]} ws "}}"' |
|
|
else: |
|
|
properties_joined = ' ws "," ws '.join(property_rules) |
|
|
object_def = f'"{{" ws {properties_joined} ws "}}"' |
|
|
|
|
|
add_rule(type_name, object_def) |
|
|
return type_name |
|
|
|
|
|
elif schema_type == "array": |
|
|
|
|
|
items_schema = schema_part.get("items", {}) |
|
|
items_type_name = f"{type_name}_items" |
|
|
item_type = process_type(items_schema, f"{type_name}_item") |
|
|
|
|
|
|
|
|
add_rule(items_type_name, f"{item_type} (ws \",\" ws {item_type})*") |
|
|
add_rule(type_name, f'"[" ws ({items_type_name})? ws "]"') |
|
|
return type_name |
|
|
|
|
|
elif schema_type == "string": |
|
|
|
|
|
if "enum" in schema_part: |
|
|
enum_values = schema_part["enum"] |
|
|
enum_options = ' | '.join([f'"\\"" "{val}" "\\""' for val in enum_values]) |
|
|
add_rule(type_name, enum_options) |
|
|
return type_name |
|
|
else: |
|
|
return "string" |
|
|
|
|
|
elif schema_type == "number" or schema_type == "integer": |
|
|
return "number" |
|
|
|
|
|
elif schema_type == "boolean": |
|
|
return "boolean" |
|
|
|
|
|
else: |
|
|
return "string" |
|
|
|
|
|
|
|
|
basic_rules_data = [ |
|
|
('ws', '[ \\t\\n]*'), |
|
|
('string', '"\\"" char* "\\""'), |
|
|
('char', '[^"\\\\] | "\\\\" (["\\\\bfnrt] | "u" hex hex hex hex)'), |
|
|
('hex', '[0-9a-fA-F]'), |
|
|
('number', '"-"? ("0" | [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?'), |
|
|
('boolean', '"true" | "false"'), |
|
|
('null', '"null"') |
|
|
] |
|
|
|
|
|
for rule_name, rule_def in basic_rules_data: |
|
|
add_rule(rule_name, rule_def) |
|
|
|
|
|
|
|
|
process_type(schema, root_name) |
|
|
|
|
|
|
|
|
return "\n".join(rules.values()) |
|
|
|
|
|
def test_grammar_generation(json_schema_str: str) -> Dict[str, Any]:
    """Build the GBNF grammar for a JSON-schema string without running the model.

    The schema is parsed directly rather than through the module-level
    ``llm_client``. This diagnostic therefore also works when the model
    failed to load.

    Returns:
        ``{"success": True, "grammar": ..., "schema": ...}`` on success, or
        ``{"success": False, "error": ...}`` when parsing or conversion fails.
    """
    try:
        parsed_schema = json.loads(json_schema_str)
        gbnf_grammar = _json_schema_to_gbnf(parsed_schema, "root")
    except json.JSONDecodeError as e:
        return {
            "success": False,
            "error": f"Invalid JSON schema: {e}",
        }
    except Exception as e:
        # Diagnostic helper: report any failure instead of raising.
        return {
            "success": False,
            "error": str(e),
        }
    return {
        "success": True,
        "grammar": gbnf_grammar,
        "schema": parsed_schema,
    }
|
|
|
|
|
|
|
|
# Module-level singleton used by the Gradio callbacks. It is None when the
# model could not be loaded; the UI checks for that and reports it.
logger.info("Initializing LLM client...")
try:
    llm_client = LLMClient()
except Exception as e:
    # logger.exception records the traceback. The UI error message points
    # users to the logs, so the full failure context must be there.
    logger.exception(f"Error initializing LLM client: {e}")
    llm_client = None
else:
    logger.info("LLM client successfully initialized")
|
|
|
|
|
def process_request(prompt: str,
                    json_schema: str,
                    image: Optional[Image.Image] = None,
                    use_grammar: bool = True) -> str:
    """Gradio callback for the "Structured Output" tab.

    Checks that the client is loaded and that neither input is blank. It then
    delegates to ``llm_client.generate_structured_response``. The result, or
    an error object, is always returned as pretty-printed JSON text.
    """
    def to_json_text(payload) -> str:
        return json.dumps(payload, ensure_ascii=False, indent=2)

    if llm_client is None:
        problem = {
            "error": "LLM client not initialized",
            "details": "Check logs for detailed error information",
        }
    elif not prompt.strip():
        problem = {"error": "Prompt cannot be empty"}
    elif not json_schema.strip():
        problem = {"error": "JSON schema cannot be empty"}
    else:
        problem = None

    if problem is not None:
        return to_json_text(problem)

    return to_json_text(llm_client.generate_structured_response(prompt, json_schema, image, use_grammar))
|
|
|
|
|
def test_gemma_chat(messages_text: str) -> str:
    """Gradio callback for the "Gemma Chat Format" tab.

    Parses ``role: message`` lines. Lines without a colon, or with a role
    other than ``user`` or ``model``, are skipped. When nothing usable
    remains, a built-in three-turn example is used instead.

    Returns:
        The rendered prompt followed by the model's reply, or an
        ``"Error: ..."`` string.
    """
    if llm_client is None:
        return "Error: LLM client not initialized"

    try:
        parsed_messages = []
        for raw_line in messages_text.strip().split('\n'):
            speaker, colon, text = raw_line.partition(':')
            if not colon:
                continue
            speaker = speaker.strip().lower()
            if speaker in ('user', 'model'):
                parsed_messages.append({"role": speaker, "content": text.strip()})

        messages = parsed_messages or [
            {"role": "user", "content": "Hello!"},
            {"role": "model", "content": "Hey there!"},
            {"role": "user", "content": "What is 1+1?"},
        ]

        formatted_prompt = llm_client._format_gemma_chat(messages)
        response = llm_client.generate_chat_response(messages, max_tokens=100)
        return f"Formatted prompt:\n{formatted_prompt}\n\nGenerated response:\n{response}"

    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
# Default JSON schema pre-filled in the "Structured Output" tab. It describes an
# object with a free-text summary, an enum-constrained sentiment, a numeric
# confidence and a list of keywords. The minimum/maximum bounds on confidence
# are not enforced by the GBNF grammar conversion in this file.
example_schema = """{
    "type": "object",
    "properties": {
        "summary": {
            "type": "string",
            "description": "Brief summary of the response"
        },
        "sentiment": {
            "type": "string",
            "enum": ["positive", "negative", "neutral"],
            "description": "Emotional tone"
        },
        "confidence": {
            "type": "number",
            "minimum": 0,
            "maximum": 1,
            "description": "Confidence level in the response"
        },
        "keywords": {
            "type": "array",
            "items": {
                "type": "string"
            },
            "description": "Key words"
        }
    },
    "required": ["summary", "sentiment", "confidence"]
}"""

# Default prompt pre-filled next to `example_schema` in the same tab.
example_prompt = "Analyze the following text and provide a structured assessment: 'The company's new product received enthusiastic user reviews. Sales exceeded all expectations by 150%.'"
|
|
|
|
|
def create_gradio_interface():
    """Assemble the Gradio app.

    The app contains a header, a model-status banner, the two feature tabs
    and a model/configuration summary.

    Returns:
        The ``gr.Blocks`` app. It is built but not launched.
    """
    with gr.Blocks(title="LLM Structured Output", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 LLM with Structured Output")
        gr.Markdown(f"Application for generating structured responses using model **{Config.MODEL_REPO}/{Config.MODEL_FILENAME}**")

        # Status banner for the module-level client created at import time.
        if llm_client is None:
            gr.Markdown("⚠️ **Warning**: Model not loaded. Check configuration and restart the application.")
        else:
            gr.Markdown("✅ **Status**: Model successfully loaded and ready to work")

        with gr.Tabs():
            with gr.TabItem("🔧 Structured Output"):
                create_structured_output_tab()

            with gr.TabItem("💬 Gemma Chat Format"):
                create_gemma_chat_tab()

        gr.Markdown(f"""
## ℹ️ Model Information

- **Model**: {Config.MODEL_REPO}/{Config.MODEL_FILENAME}
- **Local path**: {Config.MODEL_PATH}
- **Context window**: {Config.N_CTX} tokens
- **Batch size**: {Config.N_BATCH}
- **GPU layers**: {Config.N_GPU_LAYERS if Config.N_GPU_LAYERS >= 0 else "All"}
- **CPU threads**: {Config.N_THREADS}
- **Maximum response length**: {Config.MAX_NEW_TOKENS} tokens
- **Temperature**: {Config.TEMPERATURE}
- **Memory lock**: {"Enabled" if Config.USE_MLOCK else "Disabled"}
- **Memory mapping**: {"Enabled" if Config.USE_MMAP else "Disabled"}

💡 **Tips**:
- Use clear and specific JSON schemas for better results
- Enable Grammar (GBNF) mode for more precise JSON structure enforcement
- Grammar mode uses schema-based constraints to guarantee valid JSON output
- Disable Grammar mode for more flexible text generation with schema guidance

📝 **Grammar Features**:
- Automatic conversion of JSON Schema to GBNF grammar
- Strict enforcement of JSON structure during generation
- Support for objects, arrays, strings, numbers, booleans, and enums
- Improved consistency and reliability of structured outputs

🔄 **Gemma Format Features**:
- Uses proper Gemma chat tokens: `<bos>`, `<start_of_turn>`, `<end_of_turn>`
- Supports multi-turn conversations with user/model roles
- Compatible with Gemma model's expected input format
- Improved response quality with proper token structure
""")

    return demo
|
|
|
|
|
def create_structured_output_tab():
    """Build the "Structured Output" tab.

    The tab has prompt, image, schema and grammar-toggle inputs, an output
    box, and two canned examples. The submit button calls
    ``process_request``. It is meant to be called inside an open
    ``gr.Blocks`` context, as ``create_gradio_interface`` does.
    """
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt for model",
                placeholder="Enter your request...",
                lines=5,
                value=example_prompt,
            )

            image_input = gr.Image(
                label="Image (optional, for multimodal models)",
                type="pil",
            )

            schema_input = gr.Textbox(
                label="JSON schema for response structure",
                placeholder="Enter JSON schema...",
                lines=15,
                value=example_schema,
            )

            grammar_checkbox = gr.Checkbox(
                label="📝 Use Grammar (GBNF) Mode",
                value=True,
                info="Enable grammar-based structured output for more precise JSON generation",
            )

            submit_btn = gr.Button("Generate Response", variant="primary")

        with gr.Column():
            output = gr.Textbox(
                label="Structured Response",
                lines=20,
                interactive=False,
            )

    submit_btn.click(
        fn=process_request,
        inputs=[prompt_input, schema_input, image_input, grammar_checkbox],
        outputs=output,
    )

    gr.Markdown("## 📚 Usage Examples")

    # Clicking an example fills prompt, schema and image; the grammar toggle keeps its value.
    gr.Examples(
        examples=[
            [
                "Describe today's weather in New York",
                """{
    "type": "object",
    "properties": {
        "temperature": {"type": "number"},
        "description": {"type": "string"},
        "humidity": {"type": "number"}
    }
}""",
                None,
            ],
            [
                "Create a Python learning plan for one month",
                """{
    "type": "object",
    "properties": {
        "weeks": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "week_number": {"type": "integer"},
                    "topics": {"type": "array", "items": {"type": "string"}},
                    "practice_hours": {"type": "number"}
                }
            }
        },
        "total_hours": {"type": "number"}
    }
}""",
                None,
            ],
        ],
        inputs=[prompt_input, schema_input, image_input],
    )
|
|
|
|
|
def create_gemma_chat_tab():
    """Build the "Gemma Chat Format" demo tab.

    The tab has a conversation input, a button wired to ``test_gemma_chat``,
    an output box and a short explanation of the Gemma turn tokens. It is
    meant to be called inside an open ``gr.Blocks`` context, as
    ``create_gradio_interface`` does.
    """
    gr.Markdown("## 💬 Gemma Chat Format Demo")
    gr.Markdown("This tab demonstrates the Gemma chat format with `<bos>`, `<start_of_turn>`, and `<end_of_turn>` tokens.")

    with gr.Row():
        with gr.Column():
            messages_input = gr.Textbox(
                label="Conversation Messages (format: role: message per line)",
                placeholder="user: Hello!\nmodel: Hey there!\nuser: What is 1+1?",
                lines=8,
                value="user: Hello!\nmodel: Hey there!\nuser: What is 1+1?",
            )

            test_btn = gr.Button("Test Gemma Format", variant="primary")

        with gr.Column():
            chat_output = gr.Textbox(
                label="Formatted Prompt and Response",
                lines=15,
                interactive=False,
            )

    test_btn.click(
        fn=test_gemma_chat,
        inputs=messages_input,
        outputs=chat_output,
    )

    gr.Markdown("""
### 📖 Format Explanation

The Gemma chat format uses special tokens to structure conversations:
- `<bos>` - Beginning of sequence
- `<start_of_turn>user` - Start user message
- `<end_of_turn>` - End current message
- `<start_of_turn>model` - Start model response

**Example structure:**
```
<bos><start_of_turn>user
Hello!<end_of_turn>
<start_of_turn>model
Hey there!<end_of_turn>
<start_of_turn>user
What is 1+1?<end_of_turn>
<start_of_turn>model
```

This format is now used for both structured output and regular chat generation.
""")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on the configured host and
    # port, without a public share link and without debug mode.
    app = create_gradio_interface()
    launch_options = dict(
        server_name=Config.HOST,
        server_port=Config.GRADIO_PORT,
        share=False,
        debug=False,
    )
    app.launch(**launch_options)
|
|
|