project / app.py
hugsanaa's picture
Update app.py
ee83dd5 verified
from datasets import Dataset
import gradio as gr
import pandas as pd
import re
import io
import numpy as np
import pandas as pd
import os
import torch
import spaces
from transformers import (AutoModelForCausalLM,
AutoTokenizer,
QuantoConfig,
BitsAndBytesConfig,
TrainingArguments,
pipeline,
logging)
from tqdm import tqdm
import huggingface_hub
import os
# --- Hub authentication -------------------------------------------------
# The access token is read from the environment. Only log in when a token is
# actually configured: calling login() without one falls back to an
# interactive prompt, which cannot be answered on a headless deployment.
token = os.getenv("HUGGINGFACE_TOKEN")
if token:
    huggingface_hub.login(token)

# --- Model ----------------------------------------------------------------
model_path = "hugsanaa/mymodel"
device = "cuda"
compute_dtype = torch.float16

# 4-bit NF4 quantization (single quantization pass, fp16 compute).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    # device_map="cuda",
    # torch_dtype=torch.float16,
    quantization_config=bnb_config,
    token=token,
)
model.to(device)
model.config.use_cache = False
model.config.pretraining_tp = 1

# --- Tokenizer --------------------------------------------------------------
# Left padding keeps the end of every prompt adjacent to the generated tokens
# when several prompts are batched together.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
    max_length=512,
    padding_side="left",
    add_eos_token=True,
    token=token,
)
# Reuse EOS as the padding token so that batched inputs can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Preprocessing function
# Patterns are compiled once because preprocess_text runs on every row.
_REPEATED_PUNCT_RE = re.compile(r'([.,?!;])\1+')
_DISALLOWED_CHARS_RE = re.compile(r'[^a-zA-Z0-9.,?!; ]')


def preprocess_text(text):
    """
    Normalize raw text before it is placed into the model prompt.

    Steps, applied in this order:
    1. Lowercase conversion.
    2. Runs of the same punctuation mark (. , ? ! ;) collapse to one mark.
    3. Every character other than ASCII letters, digits, those punctuation
       marks and the space character is deleted. Tabs and newlines are not
       in the allowed set, so they are deleted rather than turned into spaces.
    4. Leading/trailing whitespace is dropped and inner runs of spaces
       collapse to a single space.

    Args:
        text: The raw text. ``None`` and NaN (what pandas yields for a blank
            cell) are treated as empty text; any other non-string value is
            converted with ``str()``.

    Returns:
        The normalized string (possibly empty).
    """
    if not isinstance(text, str):
        # NaN is the only float that is not equal to itself.
        is_missing = text is None or (isinstance(text, float) and text != text)
        text = "" if is_missing else str(text)

    text = text.lower()
    text = _REPEATED_PUNCT_RE.sub(r'\1', text)  # "!!!" -> "!"
    text = _DISALLOWED_CHARS_RE.sub('', text)  # drop all other special characters
    return " ".join(text.split())  # collapse extra whitespace
# Generate prompt
def generate_test_prompt(data_point):
    """
    Build the classification prompt for a single record.

    Args:
        data_point: Any mapping-like object (dict, pandas Series/row) that can
            be indexed with the key "Processed Text".

    Returns:
        A three-line prompt: the task description, the labelling rule, and the
        record's text wrapped as ``Text: [...] =``. The result has no leading
        or trailing whitespace.
    """
    task_description = '[INST] You are a specialized AI designed to detect illegal discourse related to drug purchases and sales. Below is a text and you are required to distinguish between instances of illegal activity involving drugs and legal drug mentioning like in medical or non-threatening context.'
    labelling_rule = 'Choose only one label "Flag" if the text is related to illegal activities related to drugs or "Not Flag" if the text refers to drugs in a non-threatening, legal, or medical context. [/INST]'
    record_line = f'Text: [{data_point["Processed Text"]}] ='
    return "\n".join((task_description, labelling_rule, record_line))
# Matches the "label: [<value>]" fragment in lower-cased model output.
_LABEL_RE = re.compile(r"label:\s*\[(.*?)\]")


def _verdict_from_output(generated_text):
    """Translate one decoded generation into the user-facing verdict message."""
    match = _LABEL_RE.search(generated_text.lower().strip())
    label = match.group(1).strip() if match else ""
    # "not flag" has to be checked first because it contains "flag".
    if "not flag" in label:
        return "The text refers to legal activity. (النص يشير إلى سياق قانوني)"
    if "flag" in label:
        return "Illegal activity detected! (تم اكتشاف نشاط غير قانوني)"
    return "Insufficient information to determine legality. (لا توجد معلومات كافية لتحديد ما إذا كان النص قانونيًا أم لا)"


# Prediction function
@spaces.GPU(duration=120)
def predict(X_test):
    """
    Run the model over a batch of prompts and return one verdict per prompt.

    Args:
        X_test: DataFrame whose "Processed Text" column holds the complete
            prompts (one per row).

    Returns:
        A list of verdict messages, in row order. An empty frame yields an
        empty list without invoking the tokenizer or the model.

    Note:
        All prompts are tokenized and generated as a single batch.
        NOTE(review): no generation length is passed here, so the model's own
        generation settings apply - confirm they leave room for the label.
    """
    prompts = X_test["Processed Text"].tolist()
    if not prompts:
        return []

    # Tokenize the whole batch and move the tensors to the model's device.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate outputs
    outputs = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id)

    # Decode outputs and map each one to its verdict message.
    results = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return [_verdict_from_output(result) for result in results]
# Function to process multiple CSVs
def _read_table(path):
    """Load *path* as a DataFrame, or return None when the extension is unsupported."""
    lowered = str(path).lower()  # accept ".CSV" / ".XLSX" as well
    if lowered.endswith('.csv'):
        return pd.read_csv(path)
    if lowered.endswith(('.xls', '.xlsx')):
        return pd.read_excel(path)
    return None


def process_csv(input_data):
    """
    Classify every row of the uploaded CSV/Excel files.

    Args:
        input_data: Iterable of uploaded files. Each item is either an object
            exposing its path as ``.name`` or a plain path string.

    Returns:
        On success, ``(final_df, files)``: the combined results table (file
        name, original text, prediction) and the list of Excel files written
        by ``export_excel``.
        On failure, ``(error_message, None)``. This happens when nothing was
        uploaded, a file has an unsupported extension, or a file has no
        'Text' column.
    """
    if not input_data:
        return "Error: no files were uploaded.", None

    all_results = []  # Store results from all files
    for file in input_data:
        path = getattr(file, "name", file)

        df = _read_table(path)
        if df is None:
            # Without this guard the frame of the previous file would be reused
            # (or the name would be unbound on the first iteration).
            return f"Error: {path} is not a supported file type (.csv, .xls, .xlsx).", None
        if 'Text' not in df.columns:
            return f"Error: {path} must contain a 'Text' column.", None

        # Preprocess text and wrap each row in the model prompt.
        processed = df['Text'].apply(preprocess_text)
        prompts = [generate_test_prompt({"Processed Text": text}) for text in processed]
        X_test = pd.DataFrame({"Processed Text": prompts})

        # Get predictions
        predictions = predict(X_test)

        # Build the output table directly so that columns already present in
        # the upload (e.g. its own "Filename" column) cannot clash with ours.
        result = pd.DataFrame(
            {
                "Filename (اسم الملف)": os.path.basename(path),
                "Text (النص)": df['Text'],
                "Prediction (نوع النص)": predictions,
            },
            index=df.index,
        )
        print(result.head())
        all_results.append(result)

    # Combine all files into one DataFrame
    final_df = pd.concat(all_results, ignore_index=True)
    files = export_excel(final_df)
    return final_df, list(files)  # Return dataframe and filenames for download links
def process_text(text):
    """
    Classify a single piece of text.

    Args:
        text: The raw text entered by the user. ``None`` is treated as an
            empty string.

    Returns:
        The verdict message produced by ``predict`` for this text, or
        "No prediction returned" if ``predict`` yields nothing.
    """
    # Preprocess
    processed_text = preprocess_text(text if text is not None else "")

    # Generate prompt
    prompt = generate_test_prompt({"Processed Text": processed_text})

    # predict() expects the prompts in a "Processed Text" column.
    X_test = pd.DataFrame({"Processed Text": [prompt]})

    # Predict
    prediction = predict(X_test)
    return prediction[0] if prediction else "No prediction returned"
def export_excel(results):
    """
    Write one Excel file per source file found in *results*.

    Args:
        results: DataFrame containing a "Filename (اسم الملف)" column; rows are
            grouped by that column, in order of first appearance.

    Returns:
        List of the written file names (``Predicted_<stem>.xlsx``, created in
        the current working directory). ``<stem>`` is the source file name
        without its final extension, so dotted names such as "a.b.csv" keep
        their full stem. If two sources share a stem (e.g. "a.csv" and
        "a.xlsx"), later ones get a numeric suffix instead of overwriting the
        earlier file.
    """
    filename_column = "Filename (اسم الملف)"
    exported_files = []
    used_names = set()

    # Split the dataframe based on unique filenames
    for filename, df_subset in results.groupby(filename_column, sort=False):
        stem = os.path.splitext(os.path.basename(filename))[0]
        predicted_filename = f"Predicted_{stem}.xlsx"
        duplicate_no = 1
        while predicted_filename in used_names:
            duplicate_no += 1
            predicted_filename = f"Predicted_{stem}_{duplicate_no}.xlsx"
        used_names.add(predicted_filename)

        df_subset.to_excel(predicted_filename, index=False)
        exported_files.append(predicted_filename)
    return exported_files
# Example Data for UI
# Each row is [sample input text, verdict message expected for that text].
example_data = [
    [
        "We provide Cannabis at unbeatable prices. Stealth shipping ensured.",
        "Illegal activity detected! (تم اكتشاف نشاط غير قانوني)",
    ],
    [
        "Recent studies show that MDMA has potential benefits for mental health.",
        "The text refers to legal activity. (النص يشير إلى سياق قانوني)",
    ],
]
# Create Gradio Interface
# The page has two sections: a single-text checker wired to process_text and a
# multi-file checker wired to process_csv. All labels are bilingual
# (English / Arabic).
with gr.Blocks() as demo:
    # Header
    gr.Image("image.png", show_download_button = False, show_fullscreen_button = False, show_share_button = False, label = False, show_label = False, container = False)
    gr.Markdown("This program currently supports English text only. You can enter one record or file of records. / يدعم هذا البرنامج حاليًا اللغة الإنجليزية فقط. يمكنك إدخال سجل واحد أو ملف يحتوي على سجلات متعددة.")
    # Predict Single Text
    gr.Markdown("## **Check a post for illegal activity / مراجعة المنشور للتحقق من الأنشطة غير القانونية**")
    # gr.Markdown("### **Sample records for testing the system / أمثلة عينة لاختبار النظام**")
    # gr.Dataframe(value=example_data, headers=["Text (النص)", "Expected output by the system (نوع النص المتوقع من النظام)"], datatype=["str", "str"], interactive=False)
    # Usage instructions; the two sample records are embedded in the text
    # itself. The string body stays at column 0 so its content is unchanged.
    gr.Markdown("""
**How to Use This Demo / كيفية استخدام هذا العرض التوضيحي:**
1. Write an English text sample in the text below. / **اكتب نموذجًا نصيًا باللغة الإنجليزية أدناه.**
2. Press the submit button. / **اضغط على زر الإرسال.**
3. Wait until the interface ends processing and provides the output. /**انتظر حتى ينتهي البرنامج من المعالجة ويعرض النتيجة.**
4. You can use the following sample records for testing the system. / **يمكنك استخدام الأمثلة العينة التالية لاختبار النظام.**
- **Text (النص):** We provide Cannabis at unbeatable prices. Stealth shipping ensured. - **Expected output by the system (نوع النص المتوقع من النظام):** Illegal activity detected! (تم اكتشاف نشاط غير قانوني)
- **Text (النص):** Recent studies show that MDMA has potential benefits for mental health. - **Expected output by the system (نوع النص المتوقع من النظام):** The text refers to legal activity. (النص يشير إلى سياق قانوني)
""")
    # Left column: input box and submit button; right column: the verdict.
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(lines=5, placeholder="Write the text here... (اكتب النص هنا)", label="Enter Text (أدخل النص)")
            text_button = gr.Button("Submit / إرسال")
        with gr.Column():
            text_output = gr.Textbox(lines=5, label="Prediction Result (نتيجة التنبؤ)")
    # Clicking submit sends the textbox content to process_text and shows its
    # return value in the output box.
    text_button.click(process_text, text_input, text_output)
    # Two stacked horizontal rules separate the two sections.
    gr.Markdown('<hr style="height: 10px; border: none; color: #000000; background-color: #333; margin-top: 20px; margin-bottom: 20px; width: 100%;">')
    gr.Markdown('<hr style="height: 10px; border: none; color: #000000; background-color: #333; margin-top: 20px; margin-bottom: 20px; width: 100%;">')
    # Predict Multiple Files
    gr.Markdown("## **Check a file for illegal activity / مراجعة الملف للتحقق من الأنشطة غير القانونية**")
    gr.Markdown("""
**How to Use This Demo / كيفية استخدام هذا العرض التوضيحي:**
1. Prepare an Excel or CSV file with multiple posts in a column named 'Text'. / **قم بإعداد ملف اكسل يحتوي على عدة منشورات تحت عمود بعنوان "Text".**
2. Upload your file. / **قم بتحميل الملف الخاص بك.**
3. Press submit and wait for the results. / **اضغط على زر الإرسال وانتظر النتائج.**
""")
    # Left column: file upload and submit button; right column: the results
    # table and the downloadable result files.
    with gr.Row():
        with gr.Column():
            # Only .csv and .xlsx are offered in the upload dialog.
            file_upload = gr.Files(label="Upload File (تحميل ملف)", file_types=[".csv", ".xlsx"])
            file_button = gr.Button("Submit / إرسال")
        with gr.Column():
            file_output = gr.Dataframe(headers=["Filename (اسم الملف)", "Text (النص)", "Prediction (نوع النص)"], datatype=["str", "str", "str"])
            download_files = gr.Files(interactive=False, visible=True, label="Download Results (تحميل النتائج)")
    # process_csv returns two values, mapped in order onto the table and the
    # download list.
    file_button.click(process_csv, file_upload, outputs=[file_output, download_files])
# Launch App
demo.launch()