"""Gradio Space: toxic-speech classification with a fine-tuned Gemma-3-1B model."""
import ast
import json
import time

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_PATH = 'berkeruveyik/toxic-speech-finetune-with-gemma-3-1b-v1'

# Load model and tokenizer once at import time so every request reuses them.
loaded_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype='auto',
    device_map='auto',
    attn_implementation='eager'
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

loaded_model_pipeline = pipeline(
    'text-generation',
    model=loaded_model,
    tokenizer=tokenizer
)


@spaces.GPU
def pred_on_text(input_text):
    """Run the classification pipeline on a single user message.

    Args:
        input_text: Raw text to classify.

    Returns:
        Tuple of (generated_text, raw_output, total_time):
        the assistant's reply string, the pipeline's raw output,
        and the wall-clock inference time in seconds (rounded to 4 dp).
    """
    start_time = time.time()
    raw_output = loaded_model_pipeline(
        text_inputs=[{'role': 'user', 'content': input_text}],
        max_new_tokens=256
    )
    total_time = round(time.time() - start_time, 4)
    # Chat-style pipeline output shape:
    # [{'generated_text': [user_message, assistant_message]}]
    # — index 1 is the assistant turn; 'content' is its text.
    generated_text = raw_output[0]['generated_text'][1]['content']
    return generated_text, raw_output, total_time


def parse_generated_text(text):
    """Parse the model's reply into a dict.

    Tries strict JSON first, then a Python dict literal (the model sometimes
    emits single-quoted dicts). SECURITY: uses ast.literal_eval, never eval(),
    because model output is untrusted. Falls back to {"raw_output": text}.
    """
    try:
        return json.loads(text)
    except (json.JSONDecodeError, TypeError):
        pass
    stripped = text.strip()
    if stripped.startswith('{') and stripped.endswith('}'):
        try:
            parsed = ast.literal_eval(stripped)
            # literal_eval can yield non-dict values; only accept dicts.
            if isinstance(parsed, dict):
                return parsed
        except (ValueError, SyntaxError):
            pass
    return {"raw_output": text}


def format_output(input_text, parsed_output, total_time):
    """Format the parsed classification as readable multi-line text.

    Args:
        input_text: The original user input (echoed back in the header).
        parsed_output: Dict from parse_generated_text(); only known keys
            that are present are rendered, in a fixed order.
        total_time: Inference time in seconds, shown in the footer.

    Returns:
        A newline-joined display string.
    """
    output_lines = [
        f"📝 Input: {input_text}",
        "",
        "━" * 50,
        "",
    ]
    if "is_toxic" in parsed_output:
        # Alarm emoji when toxic, check mark otherwise.
        emoji = "🚨" if parsed_output["is_toxic"] else "✅"
        output_lines.append(f"{emoji} is_toxic: {parsed_output['is_toxic']}")
    if "label" in parsed_output:
        output_lines.append(f"🏷️ label: {parsed_output['label']}")
    if "tags" in parsed_output:
        output_lines.append(f"🔖 tags: {parsed_output['tags']}")
    if "reason" in parsed_output:
        output_lines.append(f"💬 reason: {parsed_output['reason']}")
    if "severity" in parsed_output:
        output_lines.append(f"⚡ severity: {parsed_output['severity']}")
    if "raw_output" in parsed_output:
        output_lines.append(f"📄 raw_output: {parsed_output['raw_output']}")
    output_lines.append("")
    output_lines.append("━" * 50)
    output_lines.append(f"⏱️ processing_time: {total_time} seconds")
    return "\n".join(output_lines)


def gradio_predict(input_text):
    """Gradio callback: validate input, run inference, return display text."""
    if not input_text.strip():
        return "Please enter some text."
    generated_text, raw_output, total_time = pred_on_text(input_text)
    parsed_output = parse_generated_text(generated_text)
    return format_output(input_text, parsed_output, total_time)


# Gradio interface
demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Textbox(
        label="Input Text",
        placeholder="Enter your text here...",
        lines=3
    ),
    outputs=gr.Textbox(
        label="Model Output",
        lines=12
    ),
    title="🤖 Toxic Speech Classifier",
    description="Analyze whether a given text contains toxic, insulting, or harmful language using a fine-tuned Gemma3 model.",
    examples=[
        ["You are absolutely worthless and no one will ever love you."],
        ["I hope you get hit by a bus, you disgusting excuse for a person."],
        ["The weather today is really nice, I enjoyed my walk in the park."],
        ["Shut up you brainless moron, nobody asked for your stupid opinion."],
        ["Thank you for your help, I really appreciate everything you did."],
        ["You are such a pathetic loser, get out of my sight."],
        ["I just finished reading a great book, it was very inspiring."],
    ],
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    demo.launch()