import re

from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    BlenderbotForConditionalGeneration,
    BlenderbotTokenizer,
    pipeline,
)


app = Flask(__name__)
CORS(app)


# Load chatbot model
model_name = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
model = BlenderbotForConditionalGeneration.from_pretrained(model_name)

# Load POS tagging pipeline
pos_pipe = pipeline("token-classification", model="TweebankNLP/bertweet-tb2-pos-tagging")
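# Without an aggregation strategy, this pipeline yields one dict per token,
# with keys such as 'entity' (a UPOS tag like NOUN, ADJ, or ADP), 'word',
# 'score', 'start', and 'end'.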

# Load NER model
model_checkpoint = "huggingface-course/bert-finetuned-ner"
ner_model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
ner_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
token_classifier = pipeline(
    "token-classification", model=ner_model, aggregation_strategy="simple", tokenizer=ner_tokenizer,
)
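# With aggregation_strategy="simple", the classifier returns one dict per
# grouped entity; an illustrative call (scores and offsets are examples):
#   token_classifier("My name is Sylvain and I work at Hugging Face in Brooklyn.")
#   -> [{'entity_group': 'PER', 'score': 0.998, 'word': 'Sylvain', 'start': 11, 'end': 18},
#       {'entity_group': 'ORG', 'score': 0.979, 'word': 'Hugging Face', 'start': 33, 'end': 45},
#       {'entity_group': 'LOC', 'score': 0.993, 'word': 'Brooklyn', 'start': 49, 'end': 57}]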

# Function to clean messages
def clean_message(text):
    # Strip emojis and punctuation (keeps letters, digits, underscores, and whitespace)
    text = re.sub(r'[^\w\s]', '', text)

    # Collapse a letter repeated three or more times at the end of a word
    # to a single letter (e.g., "heyyy" -> "hey")
    text = re.sub(r'(\w*?)(\w)\2{2,}\b', r'\1\2', text)

    # Perform POS tagging
    pos_tags = pos_pipe(text)

    # Convert words to title case selectively
    words = text.split()
    cleaned_words = []

    for word in words:
        # Look up the POS tag predicted for this word; subword tokenization
        # means some words may not match any token, leaving tag as None
        tag = next((tag_info["entity"] for tag_info in pos_tags if tag_info["word"] == word), None)

        if tag in ["ADJ", "ADP"]:  # keep adjectives and adpositions lowercase
            cleaned_words.append(word.lower())
        else:  # title-case everything else (including untagged words)
            cleaned_words.append(word.title())

    # Drop single-letter words, but keep "I" and "A"
    cleaned_words = [word for word in cleaned_words if len(word) > 1 or word.upper() in ("I", "A")]

    return " ".join(cleaned_words)

# Function to extract named entities from a single message
def extract_entities(text, message_index, existing_entities=None, threshold=0.85):
    entities_dict = {"PER": [], "ORG": [], "LOC": [], "MISC": []}
    # Start from any previously noted entities; None avoids a mutable default argument
    seen_words = set(existing_entities or ())

    results = token_classifier(text)

    for entity in results:
        word = entity["word"]
        entity_type = entity["entity_group"]
        score = entity["score"]

        # Ignore low-confidence entities
        if score < threshold:
            continue

        # Skip leftover subword tokens ("##..."); with aggregation_strategy="simple"
        # these should already be merged, so this is only a safeguard
        if word.startswith("##"):
            continue

        # Ignore short words (e.g., single letters)
        if len(word) == 1:
            continue

        # Keep multi-word locations intact
        if entity_type == "LOC":
            processed_words = [word]
        else:
            processed_words = word.split()

        for single_word in processed_words:
            # Skip words that were already noted earlier
            if single_word not in seen_words:
                seen_words.add(single_word)
                # Record the word under its entity type; str.find returns the
                # first occurrence, so repeated words map to the earliest match
                if entity_type in entities_dict:
                    start = text.find(single_word)
                    entities_dict[entity_type].append({
                        "index": message_index,
                        "word": single_word,
                        "substring": (start, start + len(single_word))
                    })

    return entities_dict
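
# Illustrative return value (assuming the model tags "Paris" as LOC):
#   extract_entities("I visited Paris", message_index=0)
#   -> {'PER': [], 'ORG': [],
#       'LOC': [{'index': 0, 'word': 'Paris', 'substring': (10, 15)}],
#       'MISC': []}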


@app.route("/")
def home():
    return "Hello, World!"


@app.route("/api/home", methods=['POST','GET'])
def receive_message():
    data = request.get_json()
    message_index = data.get("index")
    message = data.get("message", "")

    print(f"Received message at index {message_index}: {message}")

    # Clean user message
    cleaned_message = clean_message(message)
    print("Cleaned Message:", cleaned_message)

    # Extract named entities from user message
    user_entities = extract_entities(cleaned_message, message_index)
    print("Extracted Entities from User's Message:", user_entities)

    # Generate chatbot response
    inputs = tokenizer(cleaned_message, return_tensors="pt")
    reply_ids = model.generate(**inputs)
    bot_response = tokenizer.decode(reply_ids[0], skip_special_tokens=True)

    print(f"Chatbot Response: {bot_response}")

    # The bot's response index will be the user message index + 1
    bot_index = message_index + 1

    # Extract named entities from chatbot response (bot index)
    bot_entities = extract_entities(bot_response, bot_index)
    print("Extracted Entities from Chatbot's Response:", bot_entities)

    return jsonify({
        'response': bot_response,
        'person_user': user_entities.get("PER", []),
        'location_user': user_entities.get("LOC", []),
        'person_bot': bot_entities.get("PER", []),
        'location_bot': bot_entities.get("LOC", [])
    })

if __name__ == "__main__":
    app.run(host="0.0.0.0", debug=True)
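
# Example request against the /api/home endpoint (adjust host/port as needed):
#   curl -X POST http://localhost:5000/api/home \
#        -H "Content-Type: application/json" \
#        -d '{"index": 0, "message": "hi, I am Anna from Berlin"}'
# The JSON response carries the bot reply plus any PER/LOC entities found in
# the user message and in the bot response.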