| """ | |
| File managing the text analysis functionality for the application. | |
| """ | |
import torch
import torch.nn.functional as F
from datasets import Dataset

from models import load_model, device


def create_chunked_dataset(dataset, tokenizer, max_length=512, stride=256):
    """
    Create a new dataset of overlapping token chunks from the original dataset.

    Args:
        dataset: Dataset with 'text' and 'label' columns.
        tokenizer: Tokenizer used to encode the text.
        max_length (int): Maximum chunk length in tokens.
        stride (int): Step between consecutive chunk starts; consecutive
            chunks overlap by max_length - stride tokens.

    Returns:
        Dataset with 'input_ids', 'attention_mask', 'chunk_id', 'example_id'
        and 'labels' columns, one row per chunk.
    """
    all_chunks = {
        'input_ids': [],
        'attention_mask': [],
        'chunk_id': [],
        'example_id': [],
        'labels': []
    }
    for idx, example in enumerate(dataset):
        text = example['text']
        label = int(example['label'])
        # First pass without truncation, just to measure the tokenized length.
        tokenized = tokenizer(text, truncation=False, padding=False)
        input_ids = tokenized['input_ids']
        if len(input_ids) <= max_length:
            # Short text: a single padded chunk is enough.
            tokenized = tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=max_length
            )
            all_chunks['input_ids'].append(tokenized['input_ids'])
            all_chunks['attention_mask'].append(tokenized['attention_mask'])
            all_chunks['chunk_id'].append(0)
            all_chunks['example_id'].append(idx)
            all_chunks['labels'].append(label)
        else:
            # Long text: slide a window of max_length tokens with the given stride.
            chunk_id = 0
            for i in range(0, len(input_ids), stride):
                chunk = input_ids[i:i + max_length]
                if len(chunk) < max_length // 2:
                    # Trailing windows only get shorter from here on, and with
                    # the default stride their tokens are already covered by
                    # the previous overlapping chunk, so stop.
                    break
                # Pad the chunk to max_length.
                attention_mask = [1] * len(chunk)
                if len(chunk) < max_length:
                    padding_length = max_length - len(chunk)
                    chunk = chunk + [tokenizer.pad_token_id] * padding_length
                    attention_mask = attention_mask + [0] * padding_length
                all_chunks['input_ids'].append(chunk)
                all_chunks['attention_mask'].append(attention_mask)
                all_chunks['chunk_id'].append(chunk_id)
                all_chunks['example_id'].append(idx)
                all_chunks['labels'].append(label)
                chunk_id += 1
    return Dataset.from_dict(all_chunks)
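
# A minimal usage sketch for create_chunked_dataset, assuming any Hugging Face
# tokenizer with a pad token; "bert-base-uncased" below is an illustrative
# choice, not a model taken from this application. With max_length=512 and
# stride=256, a 1200-token text yields kept windows starting at tokens 0, 256,
# 512 and 768, each sharing 256 tokens with its neighbour; the window at 1024
# would hold only 176 tokens (< max_length // 2) and is dropped.
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed model
#   data = Dataset.from_dict({"text": ["a very long article ..."], "label": [0]})
#   chunks = create_chunked_dataset(data, tok)
#   print(len(chunks), chunks[0]["chunk_id"], chunks[0]["example_id"])
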
def analyze_text(text: str, model_name: str):
    """
    Analyze the text for bias or neutrality using the selected classification model.

    Args:
        text (str): Text to analyze.
        model_name (str): Name of the model to use for the analysis.

    Returns:
        tuple (confidence_map, message): Confidence map and analysis message.
    """
    if not text.strip():
        return {"Empty text": 1.0}, "Please enter text to analyze."
    try:
        print("[Checkpoint] Starting classification...")
        model, tokenizer = load_model(model_name)
        # Wrap the single input text in a one-row dataset so it can be chunked.
        mini_dataset = Dataset.from_dict({"text": [text], "label": [0]})
        chunked_dataset = create_chunked_dataset(mini_dataset, tokenizer)
        print("[Checkpoint] Tokenization complete. Running model...")
        model.eval()
        model.to(device)
        all_logits = []
        for i in range(len(chunked_dataset)):
            chunk = chunked_dataset[i]
            inputs = {
                'input_ids': torch.tensor([chunk['input_ids']]).to(device),
                'attention_mask': torch.tensor([chunk['attention_mask']]).to(device),
            }
            with torch.no_grad():
                outputs = model(**inputs)
            # Drop the batch dimension and keep the per-class logits on CPU.
            all_logits.append(outputs.logits[0].cpu())
        # Aggregate the per-chunk predictions by averaging logits across chunks.
        stacked_logits = torch.stack(all_logits)
        averaged_logits = torch.mean(stacked_logits, dim=0)
        probs = F.softmax(averaged_logits, dim=0)
        predicted_class = torch.argmax(averaged_logits).item()
        confidence = probs[predicted_class].item()
        status = "neutral" if predicted_class == 1 else "biased"
        tag = "✅" if predicted_class == 1 else "⚠️"
        message = f"{tag} The text is classified as {status} with a confidence of {confidence:.2%}."
        confidence_map = {"Neutral": probs[1].item(), "Biased": probs[0].item()}
        print(f"[Checkpoint] Classification complete. Predicted answer: {status}")
        return confidence_map, message
    except ValueError as e:
        return {"Error": 1.0}, f"Configuration error: {str(e)}"
    except RuntimeError as e:
        return {"Error": 1.0}, f"Model error: {str(e)}"
    except Exception as e:
        return {"Error": 1.0}, f"Error analyzing text: {str(e)}"