Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import shutil
|
| 3 |
import gradio as gr
|
| 4 |
-
from transformers import
|
| 5 |
import pandas as pd
|
| 6 |
import torch
|
| 7 |
import matplotlib.pyplot as plt
|
|
@@ -9,7 +9,7 @@ import seaborn as sns
|
|
| 9 |
import base64
|
| 10 |
|
| 11 |
# Define constants
|
| 12 |
-
MODEL_NAME = "
|
| 13 |
FIGURES_DIR = "./figures"
|
| 14 |
EXAMPLE_DIR = "./example"
|
| 15 |
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")
|
|
@@ -36,7 +36,7 @@ if not os.path.isfile(EXAMPLE_FILE):
|
|
| 36 |
print("Loading model and tokenizer...")
|
| 37 |
try:
|
| 38 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 39 |
-
model =
|
| 40 |
model.to('cpu') # Ensure the model runs on CPU
|
| 41 |
print("Model and tokenizer loaded successfully.")
|
| 42 |
except Exception as e:
|
|
@@ -86,18 +86,15 @@ def generate_summary(prompt):
|
|
| 86 |
|
| 87 |
# Generate response
|
| 88 |
with torch.no_grad():
|
| 89 |
-
|
| 90 |
inputs,
|
| 91 |
max_length=500,
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
temperature=0.7,
|
| 95 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 96 |
-
pad_token_id=tokenizer.eos_token_id
|
| 97 |
)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
return
|
| 101 |
|
| 102 |
def analyze_data(data_file_path):
|
| 103 |
"""Perform data analysis on the uploaded CSV file."""
|
|
@@ -249,3 +246,5 @@ if __name__ == "__main__":
|
|
| 249 |
|
| 250 |
|
| 251 |
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import shutil
|
| 3 |
import gradio as gr
|
| 4 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 5 |
import pandas as pd
|
| 6 |
import torch
|
| 7 |
import matplotlib.pyplot as plt
|
|
|
|
| 9 |
import base64
|
| 10 |
|
| 11 |
# Define constants
|
| 12 |
+
MODEL_NAME = "facebook/bart-large-cnn" # Fine-tuned for summarization
|
| 13 |
FIGURES_DIR = "./figures"
|
| 14 |
EXAMPLE_DIR = "./example"
|
| 15 |
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")
|
|
|
|
| 36 |
print("Loading model and tokenizer...")
|
| 37 |
try:
|
| 38 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 39 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
| 40 |
model.to('cpu') # Ensure the model runs on CPU
|
| 41 |
print("Model and tokenizer loaded successfully.")
|
| 42 |
except Exception as e:
|
|
|
|
| 86 |
|
| 87 |
# Generate response
|
| 88 |
with torch.no_grad():
|
| 89 |
+
summary_ids = model.generate(
|
| 90 |
inputs,
|
| 91 |
max_length=500,
|
| 92 |
+
num_beams=4,
|
| 93 |
+
early_stopping=True
|
|
|
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
+
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
| 97 |
+
return summary
|
| 98 |
|
| 99 |
def analyze_data(data_file_path):
|
| 100 |
"""Perform data analysis on the uploaded CSV file."""
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
|
| 249 |
+
|
| 250 |
+
|