WhoLetMeCook committed on
Commit
0cf1faf
·
verified ·
1 Parent(s): 4c41360

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -94
app.py CHANGED
@@ -1,100 +1,26 @@
1
- import streamlit as st
2
- import numpy as np
3
- from datasets import load_dataset, Dataset
4
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
5
- from datasets import load_metric
6
  import torch
7
 
8
- # Load datasets (IMDB and SST2) and combine them
9
- @st.cache_resource
10
- def load_datasets():
11
- imdb = load_dataset('imdb', split='train[:5000]')
12
- sst2 = load_dataset('glue', 'sst2', split='train[:5000]')
13
-
14
- # Combine datasets into a single list
15
- train_list = [{'text': example['text'], 'label': example['label']} for example in imdb] + [{'text': example['sentence'], 'label': example['label']} for example in sst2]
16
- full_data = Dataset.from_list(train_list)
17
-
18
- # Split the dataset into train/validation/test
19
- train_data = full_data.train_test_split(test_size=0.2, seed=42)
20
- train_data = train_data['train'].train_test_split(test_size=0.25, seed=42) # 60% train, 20% validation, 20% test
21
- return train_data['train'], train_data['test']
22
-
23
- train_dataset, val_dataset = load_datasets()
24
-
25
- # Load the tokenizer and model
26
- @st.cache_resource
27
- def load_tokenizer_model():
28
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
29
- model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
30
- return tokenizer, model
31
-
32
- tokenizer, model = load_tokenizer_model()
33
-
34
- # Preprocess function for tokenization
35
- def preprocess_function(examples):
36
- return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
37
-
38
- # Tokenize datasets
39
- tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
40
- tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True)
41
-
42
- # Define the training arguments
43
- training_args = TrainingArguments(
44
- output_dir='./results',
45
- evaluation_strategy='epoch',
46
- learning_rate=2e-5,
47
- per_device_train_batch_size=16,
48
- per_device_eval_batch_size=16,
49
- num_train_epochs=3,
50
- weight_decay=0.01,
51
- logging_dir='./logs',
52
- )
53
-
54
- # Load accuracy metric
55
- metric = load_metric('accuracy')
56
-
57
- # Function to compute metrics
58
- def compute_metrics(eval_pred):
59
- logits, labels = eval_pred
60
- predictions = np.argmax(logits, axis=-1)
61
- return metric.compute(predictions=predictions, references=labels)
62
-
63
- # Initialize the trainer
64
- trainer = Trainer(
65
- model=model,
66
- args=training_args,
67
- train_dataset=tokenized_train_dataset,
68
- eval_dataset=tokenized_val_dataset,
69
- compute_metrics=compute_metrics,
70
- )
71
-
72
- # Streamlit UI
73
- st.title("DistilBERT Sentiment Training and Inference")
74
-
75
- # Button to start training
76
- if st.button("Train the Model"):
77
- st.write("Training the model... This will take some time.")
78
- trainer.train()
79
- st.write("Model training complete!")
80
-
81
- # User input for inference
82
- st.write("Once the model is trained, you can enter a sentence for sentiment analysis:")
83
- user_input = st.text_area("Enter a sentence:")
84
 
85
  # Function to make predictions
86
- def predict_sentiment(text):
87
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
88
  with torch.no_grad():
89
  outputs = model(**inputs)
90
- logits = outputs.logits
91
- prediction = torch.argmax(logits, dim=-1).item()
92
- return "Positive" if prediction == 1 else "Negative"
93
-
94
- # Button to generate predictions after training
95
- if st.button("Analyze Sentiment"):
96
- if user_input.strip():
97
- result = predict_sentiment(user_input)
98
- st.write(f"Predicted Sentiment: **{result}**")
99
- else:
100
- st.write("Please enter a sentence.")
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
3
  import torch
4
 
5
+ # Load the model and tokenizer
6
+ model_name = "WhoLetMeCook/ChefBERT"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Function to make predictions
11
+ def predict_emotion(text):
12
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
13
  with torch.no_grad():
14
  outputs = model(**inputs)
15
+ prediction = torch.argmax(outputs.logits, dim=-1).item()
16
+ return "Positive Emotion" if prediction == 1 else "Negative Emotion"
17
+
18
+ # Create the Gradio interface
19
+ iface = gr.Interface(fn=predict_emotion,
20
+ inputs="text",
21
+ outputs="text",
22
+ title="ChefBERT Emotion Classifier",
23
+ description="Enter a sentence and ChefBERT will predict whether the emotion is positive (1) or negative (0).")
24
+
25
+ # Launch the interface
26
+ iface.launch()