mohAhmad commited on
Commit
75e78ee
·
verified ·
1 Parent(s): 9a97e6a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from transformers import TFTForConditionalGeneration, TFTTokenizer
5
+ import torch
6
+
7
+ # Function to load data and model
8
+ @st.cache_data
9
+ def load_data(file):
10
+ data = pd.read_csv(file)
11
+ return data
12
+
13
+ # Function to predict using the model
14
+ def predict_earthquake_positions(data, model, tokenizer):
15
+ inputs = tokenizer(data.to_dict(orient='list'), return_tensors="pt", padding=True, truncation=True)
16
+ with torch.no_grad():
17
+ outputs = model.generate(inputs['input_ids'], num_beams=5, early_stopping=True)
18
+ return outputs
19
+
20
+ # Load Hugging Face model and tokenizer
21
+ @st.cache_resource
22
+ def load_model():
23
+ model = TFTForConditionalGeneration.from_pretrained("huggingface/tft")
24
+ tokenizer = TFTTokenizer.from_pretrained("huggingface/tft")
25
+ return model, tokenizer
26
+
27
+ # Streamlit App
28
+ st.title('Earthquake Detection Prediction App')
29
+
30
+ uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
31
+
32
+ if uploaded_file is not None:
33
+ data = load_data(uploaded_file)
34
+
35
+ # Display the data
36
+ st.subheader("Uploaded Data")
37
+ st.write(data)
38
+
39
+ # Load the model and tokenizer
40
+ model, tokenizer = load_model()
41
+
42
+ # Make predictions
43
+ st.subheader("Predictions")
44
+ predictions = predict_earthquake_positions(data, model, tokenizer)
45
+
46
+ # Plotting the predictions
47
+ st.subheader("Earthquake Prediction Plot")
48
+ fig, ax = plt.subplots()
49
+ ax.plot(data['x'], data['prediction'], label="Predicted Earthquake Position", color='green')
50
+ ax.axvline(x=predictions, color='red', linestyle='--', label='Predicted Earthquake Start')
51
+ ax.legend()
52
+ st.pyplot(fig)