Update app.py
Browse files
app.py
CHANGED
|
@@ -9,6 +9,18 @@ from sklearn.preprocessing import LabelEncoder
|
|
| 9 |
import requests
|
| 10 |
from io import BytesIO
|
| 11 |
import gdown
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# --- Set page configuration ---
|
| 14 |
st.set_page_config(
|
|
@@ -411,6 +423,39 @@ def style_metric_container(label, value):
|
|
| 411 |
</div>
|
| 412 |
""", unsafe_allow_html=True)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
def search_dataset(dataset, make, model=None):
|
| 415 |
"""
|
| 416 |
Search the dataset for the specified make and model. If no model is provided,
|
|
@@ -662,17 +707,21 @@ def predict_with_ranges(inputs, model, label_encoders):
|
|
| 662 |
'max_price': max_price
|
| 663 |
}
|
| 664 |
# --- Main Application ---
|
| 665 |
-
def main(
|
| 666 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 667 |
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
""", unsafe_allow_html=True)
|
| 675 |
-
|
| 676 |
inputs, predict_button = create_prediction_interface()
|
| 677 |
|
| 678 |
if predict_button:
|
|
@@ -685,24 +734,75 @@ def main(model, label_encoders, dataset):
|
|
| 685 |
- **Model Prediction**: ${prediction_results['predicted_price']:,.2f}
|
| 686 |
""")
|
| 687 |
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 692 |
else:
|
| 693 |
-
|
| 694 |
|
| 695 |
-
|
| 696 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
|
| 698 |
if __name__ == "__main__":
|
| 699 |
-
|
| 700 |
-
# Load data and model
|
| 701 |
-
original_data = load_datasets()
|
| 702 |
-
model, label_encoders = load_model_and_encodings()
|
| 703 |
-
|
| 704 |
-
# Call the main function
|
| 705 |
-
main(model, label_encoders, original_data)
|
| 706 |
-
except Exception as e:
|
| 707 |
-
st.error(f"Error loading data or models: {str(e)}")
|
| 708 |
-
st.stop()
|
|
|
|
| 9 |
import requests
|
| 10 |
from io import BytesIO
|
| 11 |
import gdown
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 14 |
+
import torch
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
# --- Set page configuration ---
|
| 18 |
+
st.set_page_config(
|
| 19 |
+
page_title="Car Analysis Tool",
|
| 20 |
+
page_icon="🚗",
|
| 21 |
+
layout="wide",
|
| 22 |
+
initial_sidebar_state="expanded"
|
| 23 |
+
)
|
| 24 |
|
| 25 |
# --- Set page configuration ---
|
| 26 |
st.set_page_config(
|
|
|
|
| 423 |
</div>
|
| 424 |
""", unsafe_allow_html=True)
|
| 425 |
|
| 426 |
+
def classify_image(image):
|
| 427 |
+
try:
|
| 428 |
+
model_name = "dima806/car_models_image_detection"
|
| 429 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
| 430 |
+
model = AutoModelForImageClassification.from_pretrained(model_name)
|
| 431 |
+
|
| 432 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
| 433 |
+
|
| 434 |
+
with torch.no_grad():
|
| 435 |
+
outputs = model(**inputs)
|
| 436 |
+
|
| 437 |
+
logits = outputs.logits
|
| 438 |
+
predicted_class_idx = logits.argmax(-1).item()
|
| 439 |
+
predicted_class_label = model.config.id2label[predicted_class_idx]
|
| 440 |
+
score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
|
| 441 |
+
|
| 442 |
+
return [{'label': predicted_class_label, 'score': score}]
|
| 443 |
+
except Exception as e:
|
| 444 |
+
st.error(f"Classification error: {e}")
|
| 445 |
+
return None
|
| 446 |
+
|
| 447 |
+
def get_car_overview(brand, model, year):
|
| 448 |
+
try:
|
| 449 |
+
prompt = f"Provide an overview of the following car:\nYear: {year}\nMake: {brand}\nModel: {model}\n"
|
| 450 |
+
response = openai.ChatCompletion.create(
|
| 451 |
+
model="gpt-3.5-turbo",
|
| 452 |
+
messages=[{"role": "user", "content": prompt}]
|
| 453 |
+
)
|
| 454 |
+
return response.choices[0].message['content']
|
| 455 |
+
except Exception as e:
|
| 456 |
+
st.error(f"Error getting car overview: {str(e)}")
|
| 457 |
+
return None
|
| 458 |
+
|
| 459 |
def search_dataset(dataset, make, model=None):
|
| 460 |
"""
|
| 461 |
Search the dataset for the specified make and model. If no model is provided,
|
|
|
|
| 707 |
'max_price': max_price
|
| 708 |
}
|
| 709 |
# --- Main Application ---
|
| 710 |
+
def main():
|
| 711 |
+
# Load necessary data and models
|
| 712 |
+
try:
|
| 713 |
+
original_data = load_datasets()
|
| 714 |
+
model, label_encoders = load_model_and_encodings()
|
| 715 |
+
except Exception as e:
|
| 716 |
+
st.error(f"Error loading data or models: {str(e)}")
|
| 717 |
+
st.stop()
|
| 718 |
|
| 719 |
+
# Create tabs
|
| 720 |
+
tab1, tab2 = st.tabs(["Price Prediction", "Image Analysis"])
|
| 721 |
+
|
| 722 |
+
with tab1:
|
| 723 |
+
st.title("Car Price Prediction")
|
| 724 |
+
# [Previous prediction interface code]
|
|
|
|
|
|
|
| 725 |
inputs, predict_button = create_prediction_interface()
|
| 726 |
|
| 727 |
if predict_button:
|
|
|
|
| 734 |
- **Model Prediction**: ${prediction_results['predicted_price']:,.2f}
|
| 735 |
""")
|
| 736 |
|
| 737 |
+
# Generate and display the graph
|
| 738 |
+
fig = create_market_trends_plot_with_model(model, inputs["make"], inputs, label_encoders)
|
| 739 |
+
if fig:
|
| 740 |
+
st.pyplot(fig)
|
| 741 |
+
|
| 742 |
+
with tab2:
|
| 743 |
+
st.title("Car Image Analysis")
|
| 744 |
+
|
| 745 |
+
# File uploader and camera input
|
| 746 |
+
uploaded_file = st.file_uploader("Choose a car image", type=["jpg", "jpeg", "png"])
|
| 747 |
+
camera_image = st.camera_input("Or take a picture of the car")
|
| 748 |
+
|
| 749 |
+
# Process the image
|
| 750 |
+
if uploaded_file is not None:
|
| 751 |
+
image = Image.open(uploaded_file)
|
| 752 |
+
elif camera_image is not None:
|
| 753 |
+
image = Image.open(camera_image)
|
| 754 |
else:
|
| 755 |
+
image = None
|
| 756 |
|
| 757 |
+
if image is not None:
|
| 758 |
+
st.image(image, caption='Uploaded Image', use_container_width=True)
|
| 759 |
+
|
| 760 |
+
# Classify the image
|
| 761 |
+
with st.spinner('Analyzing image...'):
|
| 762 |
+
car_classifications = classify_image(image)
|
| 763 |
+
|
| 764 |
+
if car_classifications:
|
| 765 |
+
top_prediction = car_classifications[0]['label']
|
| 766 |
+
make_name, model_name = top_prediction.split(' ', 1)
|
| 767 |
+
current_year = datetime.now().year
|
| 768 |
+
|
| 769 |
+
# Display results
|
| 770 |
+
col1, col2 = st.columns(2)
|
| 771 |
+
col1.metric("Identified Make", make_name)
|
| 772 |
+
col2.metric("Identified Model", model_name)
|
| 773 |
+
|
| 774 |
+
# Get car overview
|
| 775 |
+
overview = get_car_overview(make_name, model_name, current_year)
|
| 776 |
+
if overview:
|
| 777 |
+
st.subheader("Car Overview")
|
| 778 |
+
st.write(overview)
|
| 779 |
+
|
| 780 |
+
# Use the prediction model with the identified car
|
| 781 |
+
st.subheader("Price Analysis for Identified Car")
|
| 782 |
+
auto_inputs = {
|
| 783 |
+
'year': current_year,
|
| 784 |
+
'make': make_name.lower(),
|
| 785 |
+
'model': model_name.lower(),
|
| 786 |
+
'condition': 'good', # Default values
|
| 787 |
+
'fuel': 'gas',
|
| 788 |
+
'odometer': 0,
|
| 789 |
+
'title_status': 'clean',
|
| 790 |
+
'transmission': 'automatic',
|
| 791 |
+
'drive': 'fwd',
|
| 792 |
+
'size': 'mid-size',
|
| 793 |
+
'type': 'sedan',
|
| 794 |
+
'paint_color': 'white'
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
# Get prediction for the identified car
|
| 798 |
+
prediction_results = predict_with_ranges(auto_inputs, model, label_encoders)
|
| 799 |
+
|
| 800 |
+
st.markdown(f"""
|
| 801 |
+
### Estimated Price Range
|
| 802 |
+
- **Minimum**: ${prediction_results['min_price']:,.2f}
|
| 803 |
+
- **Maximum**: ${prediction_results['max_price']:,.2f}
|
| 804 |
+
- **Predicted**: ${prediction_results['predicted_price']:,.2f}
|
| 805 |
+
""")
|
| 806 |
|
| 807 |
if __name__ == "__main__":
|
| 808 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|