import ast
import json
import time

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
# Hugging Face Hub id of the fine-tuned Gemma-3 1B toxic-speech checkpoint.
MODEL_PATH = 'berkeruveyik/toxic-speech-finetune-with-gemma-3-1b-v1'


# Load the causal-LM weights once at import time so every request reuses them.
loaded_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype='auto',          # use the checkpoint's native precision
    device_map='auto',           # let accelerate place layers on GPU/CPU
    attn_implementation='eager'  # NOTE(review): eager attention requested explicitly — presumably required for this Gemma-3 variant; confirm
)


# Tokenizer matching the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


# Chat-style text-generation pipeline wrapping the model + tokenizer above.
loaded_model_pipeline = pipeline(
    'text-generation',
    model=loaded_model,
    tokenizer=tokenizer
)
|
|
@spaces.GPU
def pred_on_text(input_text):
    """Run the chat pipeline on *input_text* and time the call.

    Returns a tuple of (assistant reply text, raw pipeline output,
    elapsed wall-clock seconds rounded to 4 decimal places).
    """
    started = time.time()

    raw_output = loaded_model_pipeline(
        text_inputs=[{'role': 'user', 'content': input_text}],
        max_new_tokens=256
    )

    elapsed = round(time.time() - started, 4)

    # The pipeline returns the whole chat transcript; index 1 is the
    # assistant turn appended after the user message.
    reply = raw_output[0]['generated_text'][1]['content']

    return reply, raw_output, elapsed
|
|
def parse_generated_text(text):
    """Parse the model's generated text into a dict.

    Tries strict JSON first, then a Python literal (the model sometimes
    emits single-quoted dicts, which json.loads rejects). If neither
    succeeds, the raw text is wrapped under the ``raw_output`` key so the
    caller always receives a dict-like result.

    Args:
        text: The model-generated string to parse.

    Returns:
        The parsed object (usually a dict), or ``{"raw_output": text}``
        when the text is not parseable.
    """
    try:
        return json.loads(text)
    except (json.JSONDecodeError, TypeError):
        pass

    stripped = text.strip()
    if stripped.startswith('{') and stripped.endswith('}'):
        try:
            # ast.literal_eval is safe on untrusted model output,
            # unlike eval(), which would execute arbitrary code.
            parsed = ast.literal_eval(stripped)
            if isinstance(parsed, dict):
                return parsed
        except (ValueError, SyntaxError):
            pass

    return {"raw_output": text}
|
|
def format_output(input_text, parsed_output, total_time):
    """Format the parsed model output as readable multi-line text.

    Fixes the encoding-corrupted emoji literals from the original (the
    ``is_toxic`` marker string was mojibake broken across two lines,
    which is invalid Python); each known field is rendered on its own
    line with a clean emoji prefix.

    Args:
        input_text: The user's original input string.
        parsed_output: Dict produced by ``parse_generated_text``.
        total_time: Processing time in seconds.

    Returns:
        A newline-joined report string.
    """
    output_lines = []
    output_lines.append(f"🔍 Input: {input_text}")
    output_lines.append("")
    output_lines.append("─" * 50)
    output_lines.append("")

    # Alert emoji for toxic content, check mark for clean content.
    if "is_toxic" in parsed_output:
        emoji = "🚨" if parsed_output["is_toxic"] else "✅"
        output_lines.append(f"{emoji} is_toxic: {parsed_output['is_toxic']}")

    if "label" in parsed_output:
        output_lines.append(f"🏷️ label: {parsed_output['label']}")

    if "tags" in parsed_output:
        output_lines.append(f"📌 tags: {parsed_output['tags']}")

    if "reason" in parsed_output:
        output_lines.append(f"💬 reason: {parsed_output['reason']}")

    if "severity" in parsed_output:
        output_lines.append(f"⚡ severity: {parsed_output['severity']}")

    # Fallback field emitted by parse_generated_text for unparseable text.
    if "raw_output" in parsed_output:
        output_lines.append(f"📄 raw_output: {parsed_output['raw_output']}")

    output_lines.append("")
    output_lines.append("─" * 50)
    output_lines.append(f"⏱️ processing_time: {total_time} seconds")

    return "\n".join(output_lines)
|
|
def gradio_predict(input_text):
    """Gradio entry point: validate input, classify it, render the report."""
    if not input_text.strip():
        return "Please enter some text."

    reply, _raw, seconds = pred_on_text(input_text)
    parsed = parse_generated_text(reply)

    return format_output(input_text, parsed, seconds)
|
|
| |
# Gradio UI: a single input textbox feeding gradio_predict, with the
# formatted classification report shown in a read-only textbox.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Textbox(
        label="Input Text",
        placeholder="Enter your text here...",
        lines=3
    ),
    outputs=gr.Textbox(
        label="Model Output",
        lines=12
    ),
    title="🤖 Toxic Speech Classifier",
    description="Analyze whether a given text contains toxic, insulting, or harmful language using a fine-tuned Gemma3 model.",
    # Clickable example inputs (mix of toxic and benign sentences).
    examples=[
        ["You are absolutely worthless and no one will ever love you."],
        ["I hope you get hit by a bus, you disgusting excuse for a person."],
        ["The weather today is really nice, I enjoyed my walk in the park."],
        ["Shut up you brainless moron, nobody asked for your stupid opinion."],
        ["Thank you for your help, I really appreciate everything you did."],
        ["You are such a pathetic loser, get out of my sight."],
        ["I just finished reading a great book, it was very inspiring."],
    ],
    theme=gr.themes.Soft()
)


# Launch the web app when executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|