dejanseo committed on
Commit
432c5a1
·
verified ·
1 Parent(s): 4cc104c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -66
app.py CHANGED
@@ -3,6 +3,11 @@ import torch
3
  import torch.nn.functional as F
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import re
 
 
 
 
 
6
 
7
  # Set the page configuration
8
  st.set_page_config(
@@ -15,7 +20,7 @@ st.set_page_config(
15
  st.logo(
16
  image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
17
  link="https://dejan.ai/",
18
- size="large"
19
  )
20
 
21
  # Font styling
@@ -28,88 +33,171 @@ st.markdown("""
28
  </style>
29
  """, unsafe_allow_html=True)
30
 
31
- # Load model and tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  MODEL_NAME = "dejanseo/ai-detection"
33
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
34
 
35
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
36
- model = AutoModelForSequenceClassification.from_pretrained(
37
- MODEL_NAME,
38
- device_map="auto",
39
- torch_dtype=torch.float32 # ensure safe fallback on CPU
40
- )
41
- model.eval()
42
 
43
  # Static settings
44
  LABELS = ["AI Content", "Human Content"]
45
  COLORS = ["#ffe5e5", "#e6ffe6"] # light red, light green
46
 
47
- # Regex-based sentence splitter
48
  def sent_tokenize(text):
49
- return re.split(r'(?<=[.!?]) +', text.strip())
 
 
 
50
 
51
- def split_into_chunks(text, max_length=512):
52
  sentences = sent_tokenize(text)
53
- chunks, current_chunk, current_len = [], [], 0
 
 
 
 
 
54
  for sent in sentences:
55
- token_len = len(tokenizer.tokenize(sent))
56
- if current_len + token_len <= max_length - 2:
57
- current_chunk.append(sent)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  current_len += token_len
59
  else:
60
- if current_chunk:
61
- chunks.append(" ".join(current_chunk))
62
- current_chunk = [sent]
 
 
63
  current_len = token_len
64
- if current_chunk:
65
- chunks.append(" ".join(current_chunk))
 
 
 
66
  return chunks
67
 
68
- # UI
69
  st.title("AI Article Detection")
70
- text = st.text_area("Enter text to classify", height=100)
71
 
72
- if st.button("Classify"):
73
- if not text.strip():
74
  st.warning("Please enter some text.")
75
  else:
76
- with st.spinner("Analyzing..."):
77
- chunks = split_into_chunks(text)
78
- inputs = tokenizer(chunks, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
79
-
80
- with torch.no_grad():
81
- outputs = model(**inputs)
82
- logits = outputs.logits
83
- probs = F.softmax(logits, dim=-1)
84
- preds = torch.argmax(probs, dim=-1)
85
-
86
- chunk_results = []
87
- for i, chunk in enumerate(chunks):
88
- pred = int(preds[i].item())
89
- chunk_results.append({
90
- "text": chunk,
91
- "label": LABELS[pred],
92
- "color": COLORS[pred],
93
- "conf": probs[i][pred].item() * 100,
94
- })
95
-
96
- avg_probs = torch.mean(probs, dim=0).tolist()
97
- final_class = int(torch.argmax(torch.tensor(avg_probs)).item())
98
- final_label = LABELS[final_class]
99
- final_conf = avg_probs[final_class] * 100
100
-
101
- st.subheader("📊 Final Prediction")
102
- st.markdown(
103
- f"<div style='background-color:{COLORS[final_class]}; padding:1rem; border-radius:0.5rem'>"
104
- f"<b>{final_label}</b> ({final_conf:.1f}%)</div>",
105
- unsafe_allow_html=True
106
- )
107
-
108
- with st.expander("See per-chunk predictions"):
109
- for result in chunk_results:
110
- st.markdown(
111
- f"<div title='Confidence: {result['conf']:.1f}%' "
112
- f"style='background-color:{result['color']}; padding:0.75rem; margin-bottom:0.5rem; border-radius:0.5rem'>"
113
- f"{result['text']}</div>",
114
- unsafe_allow_html=True
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch.nn.functional as F
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import re
6
+ import logging # Optional: Add logging for better debugging
7
+
8
+ # Set up logging (optional but helpful)
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
 
12
  # Set the page configuration
13
  st.set_page_config(
 
20
  st.logo(
21
  image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
22
  link="https://dejan.ai/",
23
+ # size="large" # 'size' is not a valid argument for st.logo as of Streamlit 1.34 - remove or adjust if needed
24
  )
25
 
26
  # Font styling
 
33
  </style>
34
  """, unsafe_allow_html=True)
35
 
36
@st.cache_resource  # Reuse one tokenizer/model pair across Streamlit reruns
def load_model_and_tokenizer(model_name):
    """Load the tokenizer and sequence-classification model for *model_name*.

    The model is placed on CUDA when available (CPU otherwise) and switched
    to evaluation mode. On CUDA devices that report bfloat16 support the
    weights are loaded as bfloat16; in every other case float32 is used.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, device)`` tuple, where ``device`` is the
        ``torch.device`` the model was moved to.
    """
    logger.info(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Probe CUDA once and derive both the device and the dtype from it.
    has_cuda = torch.cuda.is_available()
    device_type = "cuda" if has_cuda else "cpu"
    if has_cuda and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    logger.info(f"Using device: {device_type} with dtype: {dtype}")

    logger.info(f"Loading model: {model_name}")
    # No device_map here: the whole model is moved explicitly below.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=dtype,
    )

    logger.info("Moving model to target device...")
    target_device = torch.device(device_type)
    model.to(target_device)
    model.eval()  # Inference only
    logger.info("Model loaded successfully.")
    return tokenizer, model, target_device
60
+
61
# Load the tokenizer and model once through the cached loader. If loading
# fails, show the error in the page, log the traceback, and halt the script.
MODEL_NAME = "dejanseo/ai-detection"
try:
    tokenizer, model, device = load_model_and_tokenizer(MODEL_NAME)
except Exception as load_error:
    st.error(f"Error loading model: {load_error}")
    # logger.exception logs at ERROR level and attaches the traceback.
    logger.exception(f"Failed to load model or tokenizer: {load_error}")
    st.stop()  # Nothing below can work without a model
69
 
 
 
 
 
 
 
 
70
 
71
  # Static settings
72
  LABELS = ["AI Content", "Human Content"]
73
  COLORS = ["#ffe5e5", "#e6ffe6"] # light red, light green
74
 
75
# Regex-based sentence splitter
def sent_tokenize(text):
    """Split *text* into sentences.

    A split happens at any run of whitespace that immediately follows
    '.', '!' or '?'. Surrounding whitespace is stripped from the input
    first, and empty pieces are dropped, so blank input yields ``[]``.
    """
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    # filter(None, ...) discards the empty string produced by blank input.
    return list(filter(None, pieces))
81
 
82
def split_into_chunks(text, tokenizer, max_length=512):
    """Group the sentences of *text* into chunks that fit a token budget.

    Sentences (from ``sent_tokenize``) are packed greedily, in order, into
    chunks whose summed token count stays within ``max_length - 2``; the two
    reserved positions leave room for the special tokens added when a chunk
    is encoded for the model. Sentences in a chunk are joined with a space.

    A single sentence that exceeds the budget on its own is no longer
    truncated (which dropped its tail from the analysis). It is cut into
    consecutive windows of at most ``max_length - 2`` tokens, and each
    window is decoded back to text and emitted as its own chunk.

    Args:
        text: Raw input text.
        tokenizer: Object providing ``encode(str, add_special_tokens=False)``
            returning a list of token ids, and ``decode(list_of_ids)``
            returning a string.
        max_length: Maximum sequence length including the two special tokens.

    Returns:
        A list of chunk strings in input order; empty if *text* contains no
        sentences.

    Raises:
        ValueError: If ``max_length`` leaves no room for content tokens.
    """
    max_tokens = max_length - 2  # Account for the two special tokens
    if max_tokens < 1:
        # A non-positive budget would turn the window slice below into a
        # negative/empty slice and produce meaningless chunks.
        raise ValueError(f"max_length must be at least 3, got {max_length}")

    chunks, current_chunk_sentences, current_len = [], [], 0

    for sent in sent_tokenize(text):
        # encode() without special tokens gives the content-token count.
        token_ids = tokenizer.encode(sent, add_special_tokens=False)
        token_len = len(token_ids)

        if token_len > max_tokens:
            logger.warning(
                f"Sentence split into multiple chunks as it exceeds max_length: '{sent[:100]}...'"
            )
            # Flush what has been collected so chunk order follows the text.
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
                current_chunk_sentences, current_len = [], 0
            # Emit the long sentence as consecutive budget-sized windows.
            for start in range(0, token_len, max_tokens):
                piece = tokenizer.decode(token_ids[start:start + max_tokens]).strip()
                if piece:
                    chunks.append(piece)
            continue

        if current_len + token_len > max_tokens:
            # Current chunk is full: finalize it and start a fresh one.
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences, current_len = [], 0

        current_chunk_sentences.append(sent)
        current_len += token_len

    # Add the last, partially filled chunk.
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    return chunks
123
 
124
# --- UI ---
def _chunk_card_html(result):
    """Return the HTML card for one per-chunk result.

    The chunk text comes from the user, so it is HTML-escaped before being
    embedded; the card is rendered with ``unsafe_allow_html=True``.
    """
    from html import escape  # Function-scope import keeps this section self-contained

    return (
        f"<div title='Confidence: {result['conf']:.1f}%' "
        f"style='background-color:{result['color']}; padding:0.75rem; margin-bottom:0.5rem; border-radius:0.5rem; border: 1px solid #ddd;'>"
        f"<i>({result['label']} - {result['conf']:.1f}%)</i><br>{escape(result['text'])}</div>"
    )


st.title("AI Article Detection")
text = st.text_area("Enter text to classify", height=150, placeholder="Paste your text here...")

if st.button("Classify", type="primary"):
    if not text or not text.strip():
        st.warning("Please enter some text.")
    else:
        with st.spinner("Analyzing... Please wait."):
            try:
                # Cap the sequence length by both the model config and the
                # tokenizer. NOTE(review): for some architectures
                # max_position_embeddings is larger than the usable input
                # length; taking the minimum can only shorten chunks.
                max_len = min(
                    model.config.max_position_embeddings,
                    tokenizer.model_max_length,
                )

                chunks = split_into_chunks(text, tokenizer, max_length=max_len)
                logger.info(f"Split text into {len(chunks)} chunks.")

                if not chunks:
                    st.warning("Could not process the input text (perhaps it's too short or contains only delimiters?).")
                    st.stop()

                # Tokenize all chunks as one batch on the model's device.
                inputs = tokenizer(
                    chunks,
                    return_tensors="pt",
                    padding=True,      # Pad to the longest chunk in the batch
                    truncation=True,   # Safety net if a chunk re-encodes longer than budgeted
                    max_length=max_len,
                ).to(device)

                # Perform inference
                with torch.no_grad():
                    outputs = model(**inputs)
                    logits = outputs.logits
                    # Softmax in float32: the model may run in bfloat16, whose
                    # precision is too coarse for the displayed percentages.
                    probs = F.softmax(logits.float(), dim=-1).cpu()
                    preds = torch.argmax(probs, dim=-1)

                # Per-chunk results: predicted label, its color and confidence.
                chunk_results = []
                for i, chunk in enumerate(chunks):
                    pred_index = preds[i].item()
                    chunk_results.append({
                        "text": chunk,
                        "label": LABELS[pred_index],
                        "color": COLORS[pred_index],
                        "conf": probs[i, pred_index].item() * 100,
                    })

                # Overall prediction: class with the highest mean probability
                # across chunks.
                if probs.numel() > 0:
                    avg_probs = torch.mean(probs, dim=0)
                    final_class_index = torch.argmax(avg_probs).item()
                    final_label = LABELS[final_class_index]
                    final_conf = avg_probs[final_class_index].item() * 100

                    st.subheader("📊 Final Prediction")
                    st.markdown(
                        f"<div style='background-color:{COLORS[final_class_index]}; padding:1rem; border-radius:0.5rem; border: 1px solid #ccc;'>"
                        f"Based on the analysis, the text is most likely: <b>{final_label}</b> (Confidence: {final_conf:.1f}%)</div>",
                        unsafe_allow_html=True
                    )
                else:
                    st.warning("Could not generate predictions for the provided text.")

                # Per-chunk predictions in an expander
                with st.expander("See per-chunk predictions and confidence"):
                    if chunk_results:
                        for result in chunk_results:
                            st.markdown(_chunk_card_html(result), unsafe_allow_html=True)
                    else:
                        st.write("No chunk predictions were generated.")

            except Exception as e:
                st.error(f"An error occurred during analysis: {e}")
                logger.error(f"Analysis failed: {e}", exc_info=True)