Signvrse's picture
Update app.py (#1)
782ca05 verified
import gc
import logging
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Configure logging: every record is written both to a log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        # Explicit UTF-8: log messages embed arbitrary user text, and with a
        # non-UTF-8 platform default the handler would hit UnicodeEncodeError
        # and drop the record.
        logging.FileHandler("gloss_generator.log", encoding="utf-8"),
        logging.StreamHandler(),  # Also log to console (stderr by default)
    ],
)
logger = logging.getLogger(__name__)
# Log startup
logger.info("Starting ASL Gloss Generator application")
# Free cached CUDA memory and run the garbage collector before loading the model.
# torch.cuda.empty_cache() does nothing when CUDA was never initialised.
torch.cuda.empty_cache()
gc.collect()

model_name = "Signvrse/Glosser_Gemma2_2B"
logger.info(f"Loading tokenizer for model: {model_name}")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Tokenizer loaded successfully")
except Exception as e:
    # The app cannot run without a tokenizer: log with traceback, then abort.
    logger.exception(f"Failed to load tokenizer: {str(e)}")
    raise

torch.cuda.empty_cache()
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# 4-bit bitsandbytes loading is only reliably supported on CUDA GPUs, so it is
# requested only when a GPU is present. The quantization_config form replaces
# the deprecated bare `load_in_4bit=True` keyword with identical settings.
# On CPU, load unquantized float32 weights, because float16 is slow or
# unsupported for some CPU ops.
if device == "cuda":
    _model_load_kwargs = {
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
        "torch_dtype": torch.float16,
    }
else:
    _model_load_kwargs = {"torch_dtype": torch.float32}

logger.info(f"Loading model: {model_name}")
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        **_model_load_kwargs,
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.exception(f"Failed to load model: {str(e)}")
    raise
# Define the prompt template for glossification
# Three str.format placeholders: the system prompt, the user turn and the
# assistant turn. generate_gloss passes assistant_message="" so the prompt ends
# with "Assistant: " and the model's continuation is the gloss.
gloss_prompt = """{system_message}
User: {user_message}
Assistant: {assistant_message}"""
# Superseded ASL-oriented prompts, kept for reference only (not used at runtime):
# # System message for the assistant
# system_message = """You are a glossification assistant that translates English sentences into sign language gloss (e.g., ASL gloss). Ensure translations preserve semantic meaning and align with the provided context (e.g., question, narrative, command). Output glosses in uppercase with appropriate sign language syntax."""
# # User prompt that incorporates context and content
# user_prompt_template = """Given the <ENGLISH_SENTENCE> and the <CONTEXT>, generate the corresponding sign language gloss, ensuring semantic accuracy.
# <CONTEXT>
# {context}
# </CONTEXT>
# <ENGLISH_SENTENCE>
# {content}
# </ENGLISH_SENTENCE>"""
# Active system prompt: Kenyan Sign Language (KSL) glossing rules.
# NOTE(review): this text is presumably what the fine-tuned model saw during
# training. Confirm before editing, because wording changes may shift model
# output. The un-bulleted "If the context..." line and the "<PN>" example look
# like typos but are intentionally left verbatim for that reason.
system_message = """You are a glossification assistant translating English sentences into Kenyan Sign Language (KSL) gloss. Follow these guidelines:
- Output glosses in UPPERCASE.
- Use KSL word order: subject-verb-object.
- Preserve semantic meaning, adapting for cultural and contextual nuances.
- Handle out-of-vocabulary words by using the most common gloss or a descriptive equivalent.
- Avoid repetition errors by ensuring diverse vocabulary usage.
- If context is provided, use it to inform the gloss (e.g., question, narrative, command).
If the context or sentence is ambiguous, prioritize semantic clarity and standard KSL conventions.
- Mask Proper Noun in <PN> Mask eg. How are you John to be. YOU HOW <JOHN>"""
# User turn: wraps the caller's context and English sentence in tag-delimited
# sections. Filled via .format(context=..., content=...).
user_prompt_template = """Given the <ENGLISH_SENTENCE> and the <CONTEXT>, generate the corresponding sign language gloss, ensuring semantic accuracy.
<CONTEXT>
{context}
</CONTEXT>
<ENGLISH_SENTENCE>
{content}
</ENGLISH_SENTENCE>"""
# Define the function that will be called by the Gradio interface
def generate_gloss(english_sentence, context="None"):
    """Translate an English sentence into sign language gloss with the loaded model.

    The prompt is assembled from ``gloss_prompt``, ``system_message`` and
    ``user_prompt_template`` and run through the module-level ``tokenizer`` and
    ``model``. Failures are returned as text rather than raised, so the Gradio
    output box can display them.

    Args:
        english_sentence: English text to translate. Blank input is rejected
            without running the model.
        context: Optional free-text hint such as "A question". Gradio passes ""
            for an empty textbox, so blank values are replaced with the same
            "None" placeholder used as this parameter's default.

    Returns:
        The generated gloss, or a message starting with "Error:" when the input
        is blank, generation raises, or the model produced no text.
    """
    if english_sentence is None or not english_sentence.strip():
        logger.warning("Received an empty English sentence; skipping generation")
        return "Error: Please enter an English sentence"
    if context is None or not context.strip():
        context = "None"

    logger.info(f"Generating gloss for sentence: '{english_sentence}' with context: '{context}'")
    start_time = time.perf_counter()
    try:
        # Prepend BOS explicitly so the prompt matches the batch script format.
        # NOTE(review): if the tokenizer also inserts BOS by default (Gemma
        # tokenizers usually do), the prompt starts with two BOS tokens.
        # Confirm this matches how the model was fine-tuned.
        bos = tokenizer.bos_token or ""
        formatted_prompt = bos + gloss_prompt.format(
            system_message=system_message,
            user_message=user_prompt_template.format(context=context, content=english_sentence),
            assistant_message="",
        )
        logger.debug(f"Formatted prompt: {formatted_prompt}")

        # model.device (not a hard-coded string) lets the inputs follow wherever
        # the model's weights were placed at load time.
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
        logger.debug("Input tokenized and moved to device")

        outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
        logger.debug("Model generation completed")

        # For a causal LM, generate() returns prompt + continuation. Decode only
        # the continuation, so a user sentence that itself contains "Assistant:"
        # cannot shift the extracted span. skip_special_tokens removes EOS and
        # padding in the same step.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.info(f"Raw generated text: {generated_text}")

        inferred_gloss = generated_text.strip()
        if not inferred_gloss:
            logger.warning("Generated gloss is empty")
            inferred_gloss = "Error: Generated gloss is empty"

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"Gloss generated successfully in {elapsed_time:.2f} seconds: {inferred_gloss}")
        return inferred_gloss
    except Exception as e:
        # The broad catch is deliberate: this is the UI boundary, so any failure
        # is logged with its traceback and shown to the user as text.
        logger.exception(f"Error generating gloss: {str(e)}")
        return f"Error: {str(e)}"
# Header text shown at the top of the page.
_HEADER_MARKDOWN = """
# ASL Gloss Generator
Enter an English sentence and optional context to get its sign language gloss.
"""

# Build the Gradio UI: a two-column panel with the inputs on the left and the
# generated gloss on the right.
with gr.Blocks(title="ASL Gloss Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row(variant="panel"):
        # Left column: the sentence, its optional context and the trigger button.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="English Sentence",
                placeholder="Enter your sentence here...",
                lines=3,
                interactive=True,
            )
            context_input = gr.Textbox(
                label="Context (Optional)",
                placeholder="e.g., 'A question', 'A narrative', 'A command'",
                lines=1,
                interactive=True,
            )
            generate_button = gr.Button("Generate Gloss", variant="primary")

        # Right column: read-only box that receives the model output.
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Generated Gloss", lines=5, interactive=False)

    # A click calls generate_gloss(sentence, context) and writes its return value
    # into the output box.
    generate_button.click(
        generate_gloss,
        inputs=[text_input, context_input],
        outputs=output_text,
    )
# Launch the app. The __main__ guard keeps a server from starting when this
# module is merely imported (for example by Gradio's reload mode or by tests).
if __name__ == "__main__":
    logger.info("Launching Gradio interface")
    try:
        # share=False serves locally only; set share=True for a temporary public link.
        demo.launch(share=False)
    except Exception as e:
        logger.exception(f"Failed to launch Gradio interface: {str(e)}")
        raise
    # When run as a script, launch() blocks the main thread while the server runs,
    # so execution only reaches this line after the server has stopped.
    logger.info("Gradio interface stopped")