nnsohamnn committed on
Commit
633e441
·
verified ·
1 Parent(s): 6e6f523

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -81
app.py CHANGED
@@ -1,89 +1,118 @@
1
  import gradio as gr
2
- import numpy as np
3
  import tensorflow as tf
 
4
  import pickle
5
  from tensorflow.keras.preprocessing.sequence import pad_sequences
6
-
7
def _load_pickle(path):
    """Deserialize a pickled tokenizer from *path*."""
    with open(path, "rb") as fh:
        return pickle.load(fh)

# === Load Classifier ===
classifier_model = tf.keras.models.load_model("classifier_model.keras")
classifier_tokenizer = _load_pickle("classifier_tokenizer.pkl")

# === Load Text Generator ===
textgen_model = tf.keras.models.load_model("textgen_model.keras")
textgen_tokenizer = _load_pickle("textgen_tokenizer.pkl")

# === Settings ===
CLASS_LABELS = ['Science', 'Math', 'History']  # Replace with actual labels
MAX_LEN_CLASSIFIER = 100  # token length the classifier was trained on — TODO confirm
MAX_LEN_TEXTGEN = 50  # context window used by the generator — TODO confirm
21
-
22
-
23
# === Classifier Inference ===
def classify_text(text):
    """Return a {label: probability} mapping for *text*.

    The text is tokenized with the module-level classifier tokenizer,
    padded to MAX_LEN_CLASSIFIER, and scored by the classifier model.
    """
    encoded = classifier_tokenizer.texts_to_sequences([text])
    model_input = pad_sequences(encoded, maxlen=MAX_LEN_CLASSIFIER)
    probabilities = classifier_model.predict(model_input, verbose=0)[0]
    return {label: float(probabilities[i]) for i, label in enumerate(CLASS_LABELS)}
29
-
30
-
31
# === Text Generation with Top-k Sampling ===
def generate_text(seed_text, next_words=15, k=10, temperature=0.9):
    """Extend *seed_text* with up to *next_words* sampled words.

    Each step applies a temperature-scaled softmax to the model's
    next-word distribution, restricts sampling to the top-k indices,
    and resamples (best effort, 5 retries) when the chosen word was
    emitted recently.

    Args:
        seed_text: Starting text; it is included in the return value.
        next_words: Maximum number of words to append.
        k: Sample only among the k most probable next words.
        temperature: Softmax temperature; lower is more deterministic.

    Returns:
        The seed text with the generated words appended.
    """
    # FIFO window of recently emitted word indices. The original used a
    # set and set.pop(), which evicts an *arbitrary* element rather than
    # the oldest, so the "recent" window drifted; a bounded list keeps
    # genuine recency.
    recent_words = []
    # Invert word_index once so each step is an O(1) lookup instead of a
    # linear scan over the whole vocabulary.
    index_to_word = {index: word for word, index in textgen_tokenizer.word_index.items()}

    for _ in range(next_words):
        token_list = textgen_tokenizer.texts_to_sequences([seed_text])[0]
        token_list = pad_sequences([token_list], maxlen=MAX_LEN_TEXTGEN - 1, padding='pre')
        predicted_probs = textgen_model.predict(token_list, verbose=0)[0]

        # Temperature-scaled softmax; the epsilon guards against log(0).
        scaled_probs = np.log(predicted_probs + 1e-10) / temperature
        exp_probs = np.exp(scaled_probs)
        normalized_probs = exp_probs / np.sum(exp_probs)

        # Keep only the k most probable indices and renormalize.
        top_k_indices = np.argsort(normalized_probs)[-k:]
        top_k_probs = normalized_probs[top_k_indices]
        top_k_probs = top_k_probs / np.sum(top_k_probs)

        predicted_word_index = np.random.choice(top_k_indices, p=top_k_probs)

        # Avoid repetition: resample a few times if the word is recent.
        attempts = 0
        while predicted_word_index in recent_words and attempts < 5:
            predicted_word_index = np.random.choice(top_k_indices, p=top_k_probs)
            attempts += 1

        recent_words.append(predicted_word_index)
        if len(recent_words) > 10:
            del recent_words[0]  # drop the oldest, keeping a 10-word window

        output_word = index_to_word.get(predicted_word_index, "")
        if not output_word:
            # Index 0 (padding) or otherwise unmapped: stop instead of
            # appending a dangling space as the original did.
            break

        seed_text += " " + output_word

    return seed_text
68
-
69
-
70
# === Gradio UI ===
# Two tabs sharing one Blocks app: classification and top-k generation.
with gr.Blocks() as demo:
    gr.Markdown("## Dual Model: Text Classifier + Text Generator")

    with gr.Tab("Text Classification"):
        classify_input = gr.Textbox(label="Enter Text", lines=4)
        classify_button = gr.Button("Classify")
        label_output = gr.Label()
        classify_button.click(
            fn=classify_text,
            inputs=classify_input,
            outputs=label_output,
        )

    with gr.Tab("Text Generation"):
        seed_input = gr.Textbox(label="Seed Text", lines=4)
        words_slider = gr.Slider(5, 100, value=30, label="Words to Generate")
        topk_slider = gr.Slider(1, 50, value=10, step=1, label="Top-k Sampling")
        temperature_slider = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
        generate_button = gr.Button("Generate")
        generated_output = gr.Textbox(label="Generated Text", lines=6)
        generate_button.click(
            fn=generate_text,
            inputs=[seed_input, words_slider, topk_slider, temperature_slider],
            outputs=generated_output,
        )

demo.launch()
 
1
  import gradio as gr
 
2
  import tensorflow as tf
3
+ import numpy as np
4
  import pickle
5
  from tensorflow.keras.preprocessing.sequence import pad_sequences
6
+ import re
7
+
8
# Load models and tokenizers
def load_models():
    """Load both Keras models and their pickled tokenizers from disk.

    Returns:
        Tuple of (classifier_model, classifier_tokenizer,
        textgen_model, textgen_tokenizer).
    """
    artifacts = []
    for model_path, tokenizer_path in (
        ('classifier_model.keras', 'classifier_tokenizer.pkl'),
        ('textgen_model.keras', 'textgen_tokenizer.pkl'),
    ):
        artifacts.append(tf.keras.models.load_model(model_path))
        with open(tokenizer_path, 'rb') as handle:
            artifacts.append(pickle.load(handle))
    return tuple(artifacts)
21
+
22
# Text cleaning function
def clean_text(text):
    """Normalize raw user text before tokenization.

    Steps: strip characters other than word chars, whitespace and .,!?;
    drop standalone numbers; lowercase; expand the 'co2' abbreviation;
    collapse whitespace runs to single spaces.

    Args:
        text: Raw input string.

    Returns:
        The cleaned, lowercased string.
    """
    # Keep only word characters, whitespace, and basic punctuation.
    text = re.sub(r'[^\w\s.,!?]', '', text)
    # Remove standalone numbers; digits embedded in words (e.g. 'co2')
    # have no \b boundary and survive.
    text = re.sub(r'\b\d+\b', '', text)
    # Lowercase BEFORE the abbreviation expansion: the original replaced
    # first and so never expanded 'CO2' or 'Co2'.
    text = text.lower()
    text = text.replace('co2', 'carbon dioxide')
    # Collapse whitespace runs and trim the ends.
    return ' '.join(text.split())
30
+
31
# Classification function
def classify_text(text, model, tokenizer):
    """Classify *text* into one of the subject labels.

    Args:
        text: Raw input string; cleaned before tokenization.
        model: Keras classifier producing one probability per class.
        tokenizer: Fitted Keras tokenizer matching the model's vocabulary.

    Returns:
        (label, confidence) where confidence is a percentage.
    """
    classes = ['Science', 'Maths', 'History']
    encoded = tokenizer.texts_to_sequences([clean_text(text)])
    # NOTE(review): maxlen=255 presumably matches the training input length — confirm.
    model_input = pad_sequences(encoded, maxlen=255, padding='pre')
    probabilities = model.predict(model_input)[0]
    best = np.argmax(probabilities)
    return classes[best], probabilities[best] * 100
44
+
45
# Text generation function
def generate_text(prompt, model, tokenizer, max_length=50, temperature=0.7):
    """Generate up to *max_length* words continuing *prompt*.

    Each word is sampled from the model's temperature-scaled next-word
    distribution; generation stops early when the sampled index has no
    word in the tokenizer's vocabulary.

    Args:
        prompt: Seed text; cleaned before tokenization.
        model: Keras language model returning a next-word distribution.
        tokenizer: Fitted Keras tokenizer matching the model's vocabulary.
        max_length: Maximum number of words to append.
        temperature: Softmax temperature; lower is more deterministic.

    Returns:
        The cleaned prompt with the generated words appended.
    """
    input_text = clean_text(prompt)
    # Invert word_index once: O(1) lookup per step instead of scanning
    # the whole vocabulary on every generated word.
    index_to_word = {index: word for word, index in tokenizer.word_index.items()}

    for _ in range(max_length):
        token_list = tokenizer.texts_to_sequences([input_text])[0]
        # NOTE(review): maxlen=255 presumably matches the training window — confirm.
        token_list = pad_sequences([token_list], maxlen=255, padding='pre')

        predicted = model.predict(token_list, verbose=0)[0]

        # Apply temperature. The epsilon guards against log(0): the
        # previous code took np.log of raw probabilities, emitting
        # -inf / RuntimeWarning for zero entries.
        predicted = np.log(predicted + 1e-10) / temperature
        exp_preds = np.exp(predicted)
        predicted = exp_preds / np.sum(exp_preds)
        # Renormalize so float rounding cannot trip np.random.choice's
        # "probabilities do not sum to 1" check.
        predicted = predicted / predicted.sum()

        # Sample from the distribution
        predicted_index = np.random.choice(len(predicted), p=predicted)

        output_word = index_to_word.get(predicted_index, "")
        if output_word == "":
            break

        input_text += " " + output_word

    return input_text
76
+
77
# Load models
classifier_model, classifier_tokenizer, textgen_model, textgen_tokenizer = load_models()


# Create Gradio interface
def classify_interface(text):
    """Gradio wrapper: run the classifier and format the result string."""
    subject, confidence = classify_text(text, classifier_model, classifier_tokenizer)
    return f"Subject: {subject} (Confidence: {confidence:.2f}%)"


def generate_interface(prompt, length=50, temp=0.7):
    """Gradio wrapper: coerce slider values and run the text generator."""
    return generate_text(prompt, textgen_model, textgen_tokenizer,
                         max_length=int(length), temperature=float(temp))


# Create tabs for different functionalities
with gr.Blocks(title="Science Text Analyzer") as demo:
    gr.Markdown("# Science Text Analyzer")

    with gr.Tab("Classify Text"):
        # The classifier's labels are Science / Maths / History; the
        # original copy wrongly advertised Physics / Chemistry / Biology.
        gr.Markdown("## Classify text into Science, Maths, or History")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Enter scientific text", lines=5)
                classify_button = gr.Button("Classify")
            with gr.Column():
                output = gr.Textbox(label="Classification Result")
        classify_button.click(fn=classify_interface, inputs=text_input, outputs=output)

    with gr.Tab("Generate Text"):
        gr.Markdown("## Generate scientific text based on a prompt")
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(label="Enter a prompt", lines=3)
                length_slider = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Maximum Length")
                temp_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature (Creativity)")
                generate_button = gr.Button("Generate")
            with gr.Column():
                generated_output = gr.Textbox(label="Generated Text", lines=8)
        generate_button.click(fn=generate_interface, inputs=[prompt_input, length_slider, temp_slider], outputs=generated_output)

    gr.Markdown("### About")
    gr.Markdown("This app uses deep learning models trained on educational texts to classify and generate content related to Science, Maths, and History.")

# Launch the app
demo.launch()