kankur0007 commited on
Commit
5546ab4
·
1 Parent(s): 3a9ed46

Add application file

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from theme_classifier import ThemeClassifier
3
+ from character_network import NamedEntityRecognizer, CharacterNetworkGenerator
4
+ from text_classification import JutsuClassifier
5
+ from character_chatbot import CharacterChatBot
6
+ import os
7
+ from dotenv import load_dotenv
8
+ load_dotenv()
9
+
10
+ def get_themes(theme_list_str,subtitles_path,save_path):
11
+ theme_list = theme_list_str.split(',')
12
+ theme_classifier = ThemeClassifier(theme_list)
13
+ output_df = theme_classifier.get_themes(subtitles_path,save_path)
14
+
15
+ # Remove dialogue from the theme list
16
+ theme_list = [theme for theme in theme_list if theme != 'dialogue']
17
+ output_df = output_df[theme_list]
18
+
19
+ output_df = output_df[theme_list].sum().reset_index()
20
+ output_df.columns = ['Theme','Score']
21
+
22
+ output_chart = gr.BarPlot(
23
+ output_df,
24
+ x="Theme",
25
+ y="Score",
26
+ title="Series Themes",
27
+ tooltip=["Theme","Score"],
28
+ vertical=False,
29
+ width=500,
30
+ height=260
31
+ )
32
+
33
+ return output_chart
34
+
35
+ def get_character_network(subtitles_path,ner_path):
36
+ ner = NamedEntityRecognizer()
37
+ ner_df = ner.get_ners(subtitles_path,ner_path)
38
+
39
+ character_network_generator = CharacterNetworkGenerator()
40
+ relationship_df = character_network_generator.generate_character_network(ner_df)
41
+ html = character_network_generator.draw_network_graph(relationship_df)
42
+
43
+ return html
44
+
45
+ def classify_text(text_classifcation_model,text_classifcation_data_path,text_to_classify):
46
+ jutsu_classifier = JutsuClassifier(model_path = text_classifcation_model,
47
+ data_path = text_classifcation_data_path,
48
+ huggingface_token = os.getenv('huggingface_token'))
49
+
50
+ output = jutsu_classifier.classify_jutsu(text_to_classify)
51
+ output = output[0]
52
+
53
+ return output
54
+
55
+ def chat_with_character_chatbot(message, history):
56
+ character_chatbot = CharacterChatBot("AbdullahTarek/Naruto_Llama-3-8B",
57
+ huggingface_token = os.getenv('huggingface_token')
58
+ )
59
+
60
+ output = character_chatbot.chat(message, history)
61
+ output = output['content'].strip()
62
+ return output
63
+
64
+
65
+ def main():
66
+ with gr.Blocks() as iface:
67
+ # Theme Classification Section
68
+ with gr.Row():
69
+ with gr.Column():
70
+ gr.HTML("<h1>Theme Classification (Zero Shot Claasifiers)</h1>")
71
+ with gr.Row():
72
+ with gr.Column():
73
+ plot = gr.BarPlot()
74
+ with gr.Column():
75
+ theme_list = gr.Textbox(label="Themes")
76
+ subtitles_path = gr.Textbox(label="Subtitles or script Path")
77
+ save_path = gr.Textbox(label="Save Path")
78
+ get_themes_button =gr.Button("Get Themes")
79
+ get_themes_button.click(get_themes, inputs=[theme_list,subtitles_path,save_path], outputs=[plot])
80
+
81
+ # Character Network Section
82
+ with gr.Row():
83
+ with gr.Column():
84
+ gr.HTML("<h1>Character Network (NERs and Graphs)</h1>")
85
+ with gr.Row():
86
+ with gr.Column():
87
+ network_html = gr.HTML()
88
+ with gr.Column():
89
+ subtitles_path = gr.Textbox(label="Subtutles or Script Path")
90
+ ner_path = gr.Textbox(label="NERs save path")
91
+ get_network_graph_button = gr.Button("Get Character Network")
92
+ get_network_graph_button.click(get_character_network, inputs=[subtitles_path,ner_path], outputs=[network_html])
93
+
94
+ # Text Classification with LLMs
95
+ with gr.Row():
96
+ with gr.Column():
97
+ gr.HTML("<h1>Text Classification with LLMs</h1>")
98
+ with gr.Row():
99
+ with gr.Column():
100
+ text_classification_output = gr.Textbox(label="Text Classification Output")
101
+ with gr.Column():
102
+ text_classifcation_model = gr.Textbox(label='Model Path')
103
+ text_classifcation_data_path = gr.Textbox(label='Data Path')
104
+ text_to_classify = gr.Textbox(label='Text input')
105
+ classify_text_button = gr.Button("Clasify Text (Jutsu)")
106
+ classify_text_button.click(classify_text, inputs=[text_classifcation_model,text_classifcation_data_path,text_to_classify], outputs=[text_classification_output])
107
+
108
+ # Character Chatbot Section
109
+ with gr.Row():
110
+ with gr.Column():
111
+ gr.HTML("<h1>Character Chatbot</h1>")
112
+ gr.ChatInterface(chat_with_character_chatbot)
113
+
114
+ iface.launch(share=True)
115
+
116
+
117
+ if __name__ == '__main__':
118
+ main()