Shuja007 committed on
Commit 9a5c0c6 · verified · 1 Parent(s): fcd4967

Update app.py

Files changed (1):
  1. app.py +47 -35
app.py CHANGED
@@ -1,36 +1,48 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- import os
-
- # Set the path to your local model directory
- model_path = "./bart_samsum"
-
- # Check if the model path exists
- if not os.path.exists(model_path):
-     st.error(f"The path {model_path} does not exist. Please check the path.")
- else:
-     # Load the tokenizer and model from the local directory
-     tokenizer = AutoTokenizer.from_pretrained(model_path)
-     model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
-
-     # Streamlit app UI
-     st.title("BART Summarization Model")
-
-     input_text = st.text_area("Input Text", "Enter text here...")
-
-     if st.button("Generate Summary"):
-         if not input_text.strip():
-             st.warning("Please enter some text to summarize.")
-         else:
-             # Tokenize and generate summary
-             inputs = tokenizer(input_text, return_tensors="pt")
-             summary_ids = model.generate(inputs["input_ids"], max_length=150, num_beams=4, early_stopping=True)
-             summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-             # Display the summary
-             st.subheader("Generated Summary")
-             st.write(summary)
-
-     # Optionally, you can add a section to display model information or statistics
-     st.sidebar.title("Model Information")
-     st.sidebar.write("This app uses a fine-tuned BART model for summarization.")
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ from google_drive_downloader import GoogleDriveDownloader as gdd
+
+ # Set the title of the Streamlit app
+ st.title("Text Classification with Hugging Face Transformers")
+
+ # Function to download the model from Google Drive
+ def download_model_from_drive(file_id, dest_path):
+     gdd.download_file_from_google_drive(file_id=file_id, dest_path=dest_path, unzip=False)
+
+ # Download the model files
+ with st.spinner("Downloading model..."):
+     download_model_from_drive('1-V2bEtPR9Y3iBXK9zOR-qM5y9hKiQUnF', 'model/model.safetensors')
+     download_model_from_drive('1-T2etSP_k_3j5LzunWq8viKGQCQ5RMr_', 'model/config.json')
+     download_model_from_drive('1-cRYNPWqlNNGRxeztympRRfVuy3hWuMY', 'model/tokenizer.json')
+     download_model_from_drive('1-t9AhomeH7YIIpAqCGTok8wjvl0tml0F', 'model/vocab.json')
+     download_model_from_drive('1-l77_KEdK7GBFjMX_6UXGE-ZTGDraaDm', 'model/merges.txt')
+
+ # Load the model and tokenizer
+ @st.cache(allow_output_mutation=True)
+ def load_model_and_tokenizer():
+     tokenizer = AutoTokenizer.from_pretrained('model')
+     # For Safetensors, you might need a custom loading mechanism
+     model = AutoModelForSequenceClassification.from_pretrained('model', use_safetensors=True)  # Adjust if necessary
+     return tokenizer, model
+
+ tokenizer, model = load_model_and_tokenizer()
+
+ # Input text from user
+ input_text = st.text_area("Enter the text to classify:")
+
+ if st.button("Classify"):
+     if input_text:
+         # Tokenize the input text
+         inputs = tokenizer(input_text, return_tensors="pt")
+
+         # Perform classification
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # Get the predicted class
+         predicted_class = torch.argmax(outputs.logits, dim=1).item()
+
+         st.write(f"Predicted Class: {predicted_class}")
+     else:
+         st.write("Please enter some text to classify.")