Spaces:
Sleeping
Sleeping
Commit
·
abdf1bb
0
Parent(s):
backend commit
Browse files- .env +1 -0
- Dockerfile +20 -0
- README.md +0 -0
- app/data/country_codes.json +175 -0
- app/data/goalscorers.csv +0 -0
- app/data/model/label_encoder.pkl +0 -0
- app/data/model/linear_regression_team1_goals.pkl +0 -0
- app/data/model/linear_regression_team2_goals.pkl +0 -0
- app/data/model/logistic_regression_model.pkl +0 -0
- app/data/model/train_model.py +95 -0
- app/data/results.csv +0 -0
- app/main.py +441 -0
- requirements.txt +8 -0
.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
GROQ_API_KEY=REDACTED_ROTATE_THIS_KEY  # SECURITY: a live Groq API key was committed here in plain text; rotate it immediately and keep .env out of version control (supply the key via deployment secrets instead)
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set working directory in the container
WORKDIR /app

# Log straight to stdout/stderr and skip .pyc generation (container best practice)
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# Copy the requirements file into the container first so dependency
# installation is cached independently of app-code changes
COPY requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the entire app directory into the container
COPY app/ .

# NOTE: .env is deliberately NOT copied into the image (it holds a secret).
# main.py exits at startup unless GROQ_API_KEY is present, so supply it at
# runtime, e.g. `docker run -e GROQ_API_KEY=... <image>` or via orchestrator
# secrets.

# Expose port 8000 for the FastAPI app
EXPOSE 8000

# Command to run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md
ADDED
|
File without changes
|
app/data/country_codes.json
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"Afghanistan": "af",
|
| 3 |
+
"Albania": "al",
|
| 4 |
+
"Algeria": "dz",
|
| 5 |
+
"Andorra": "ad",
|
| 6 |
+
"Angola": "ao",
|
| 7 |
+
"Argentina": "ar",
|
| 8 |
+
"Armenia": "am",
|
| 9 |
+
"Australia": "au",
|
| 10 |
+
"Austria": "at",
|
| 11 |
+
"Azerbaijan": "az",
|
| 12 |
+
"Bahrain": "bh",
|
| 13 |
+
"Bangladesh": "bd",
|
| 14 |
+
"Belarus": "by",
|
| 15 |
+
"Belgium": "be",
|
| 16 |
+
"Benin": "bj",
|
| 17 |
+
"Bhutan": "bt",
|
| 18 |
+
"Bolivia": "bo",
|
| 19 |
+
"Bosnia and Herzegovina": "ba",
|
| 20 |
+
"Botswana": "bw",
|
| 21 |
+
"Brazil": "br",
|
| 22 |
+
"Bulgaria": "bg",
|
| 23 |
+
"Burkina Faso": "bf",
|
| 24 |
+
"Burundi": "bi",
|
| 25 |
+
"Cameroon": "cm",
|
| 26 |
+
"Canada": "ca",
|
| 27 |
+
"Cape Verde": "cv",
|
| 28 |
+
"Central African Republic": "cf",
|
| 29 |
+
"Chad": "td",
|
| 30 |
+
"Chile": "cl",
|
| 31 |
+
"China": "cn",
|
| 32 |
+
"Colombia": "co",
|
| 33 |
+
"Comoros": "km",
|
| 34 |
+
"Congo": "cg",
|
| 35 |
+
"Costa Rica": "cr",
|
| 36 |
+
"Croatia": "hr",
|
| 37 |
+
"Cuba": "cu",
|
| 38 |
+
"Cyprus": "cy",
|
| 39 |
+
"Czech Republic": "cz",
|
| 40 |
+
"Denmark": "dk",
|
| 41 |
+
"Djibouti": "dj",
|
| 42 |
+
"DR Congo": "cd",
|
| 43 |
+
"Ecuador": "ec",
|
| 44 |
+
"Egypt": "eg",
|
| 45 |
+
"El Salvador": "sv",
|
| 46 |
+
"England": "gb-eng",
|
| 47 |
+
"Equatorial Guinea": "gq",
|
| 48 |
+
"Eritrea": "er",
|
| 49 |
+
"Estonia": "ee",
|
| 50 |
+
"Eswatini": "sz",
|
| 51 |
+
"Ethiopia": "et",
|
| 52 |
+
"Fiji": "fj",
|
| 53 |
+
"Finland": "fi",
|
| 54 |
+
"France": "fr",
|
| 55 |
+
"Gabon": "ga",
|
| 56 |
+
"Gambia": "gm",
|
| 57 |
+
"Georgia": "ge",
|
| 58 |
+
"Germany": "de",
|
| 59 |
+
"Ghana": "gh",
|
| 60 |
+
"Greece": "gr",
|
| 61 |
+
"Guatemala": "gt",
|
| 62 |
+
"Guinea": "gn",
|
| 63 |
+
"Guinea-Bissau": "gw",
|
| 64 |
+
"Guyana": "gy",
|
| 65 |
+
"Haiti": "ht",
|
| 66 |
+
"Honduras": "hn",
|
| 67 |
+
"Hungary": "hu",
|
| 68 |
+
"Iceland": "is",
|
| 69 |
+
"India": "in",
|
| 70 |
+
"Indonesia": "id",
|
| 71 |
+
"Iran": "ir",
|
| 72 |
+
"Iraq": "iq",
|
| 73 |
+
"Ireland": "ie",
|
| 74 |
+
"Israel": "il",
|
| 75 |
+
"Italy": "it",
|
| 76 |
+
"Ivory Coast": "ci",
|
| 77 |
+
"Jamaica": "jm",
|
| 78 |
+
"Japan": "jp",
|
| 79 |
+
"Jordan": "jo",
|
| 80 |
+
"Kazakhstan": "kz",
|
| 81 |
+
"Kenya": "ke",
|
| 82 |
+
"Kosovo": "xk",
|
| 83 |
+
"Kuwait": "kw",
|
| 84 |
+
"Kyrgyzstan": "kg",
|
| 85 |
+
"Laos": "la",
|
| 86 |
+
"Latvia": "lv",
|
| 87 |
+
"Lebanon": "lb",
|
| 88 |
+
"Lesotho": "ls",
|
| 89 |
+
"Liberia": "lr",
|
| 90 |
+
"Libya": "ly",
|
| 91 |
+
"Liechtenstein": "li",
|
| 92 |
+
"Lithuania": "lt",
|
| 93 |
+
"Luxembourg": "lu",
|
| 94 |
+
"Madagascar": "mg",
|
| 95 |
+
"Malawi": "mw",
|
| 96 |
+
"Malaysia": "my",
|
| 97 |
+
"Maldives": "mv",
|
| 98 |
+
"Mali": "ml",
|
| 99 |
+
"Malta": "mt",
|
| 100 |
+
"Mauritania": "mr",
|
| 101 |
+
"Mauritius": "mu",
|
| 102 |
+
"Mexico": "mx",
|
| 103 |
+
"Moldova": "md",
|
| 104 |
+
"Monaco": "mc",
|
| 105 |
+
"Mongolia": "mn",
|
| 106 |
+
"Montenegro": "me",
|
| 107 |
+
"Morocco": "ma",
|
| 108 |
+
"Mozambique": "mz",
|
| 109 |
+
"Myanmar": "mm",
|
| 110 |
+
"Namibia": "na",
|
| 111 |
+
"Nepal": "np",
|
| 112 |
+
"Netherlands": "nl",
|
| 113 |
+
"New Zealand": "nz",
|
| 114 |
+
"Nicaragua": "ni",
|
| 115 |
+
"Niger": "ne",
|
| 116 |
+
"Nigeria": "ng",
|
| 117 |
+
"North Korea": "kp",
|
| 118 |
+
"North Macedonia": "mk",
|
| 119 |
+
"Norway": "no",
|
| 120 |
+
"Oman": "om",
|
| 121 |
+
"Pakistan": "pk",
|
| 122 |
+
"Palestine": "ps",
|
| 123 |
+
"Panama": "pa",
|
| 124 |
+
"Papua New Guinea": "pg",
|
| 125 |
+
"Paraguay": "py",
|
| 126 |
+
"Peru": "pe",
|
| 127 |
+
"Philippines": "ph",
|
| 128 |
+
"Poland": "pl",
|
| 129 |
+
"Portugal": "pt",
|
| 130 |
+
"Qatar": "qa",
|
| 131 |
+
"Romania": "ro",
|
| 132 |
+
"Russia": "ru",
|
| 133 |
+
"Rwanda": "rw",
|
| 134 |
+
"San Marino": "sm",
|
| 135 |
+
"Saudi Arabia": "sa",
|
| 136 |
+
"Scotland": "gb-sct",
|
| 137 |
+
"Senegal": "sn",
|
| 138 |
+
"Serbia": "rs",
|
| 139 |
+
"Seychelles": "sc",
|
| 140 |
+
"Sierra Leone": "sl",
|
| 141 |
+
"Singapore": "sg",
|
| 142 |
+
"Slovakia": "sk",
|
| 143 |
+
"Slovenia": "si",
|
| 144 |
+
"Somalia": "so",
|
| 145 |
+
"South Africa": "za",
|
| 146 |
+
"South Korea": "kr",
|
| 147 |
+
"South Sudan": "ss",
|
| 148 |
+
"Spain": "es",
|
| 149 |
+
"Sri Lanka": "lk",
|
| 150 |
+
"Sudan": "sd",
|
| 151 |
+
"Suriname": "sr",
|
| 152 |
+
"Sweden": "se",
|
| 153 |
+
"Switzerland": "ch",
|
| 154 |
+
"Syria": "sy",
|
| 155 |
+
"Tajikistan": "tj",
|
| 156 |
+
"Tanzania": "tz",
|
| 157 |
+
"Thailand": "th",
|
| 158 |
+
"Togo": "tg",
|
| 159 |
+
"Trinidad and Tobago": "tt",
|
| 160 |
+
"Tunisia": "tn",
|
| 161 |
+
"Turkey": "tr",
|
| 162 |
+
"Turkmenistan": "tm",
|
| 163 |
+
"Uganda": "ug",
|
| 164 |
+
"Ukraine": "ua",
|
| 165 |
+
"United Arab Emirates": "ae",
|
| 166 |
+
"United States": "us",
|
| 167 |
+
"Uruguay": "uy",
|
| 168 |
+
"Uzbekistan": "uz",
|
| 169 |
+
"Venezuela": "ve",
|
| 170 |
+
"Vietnam": "vn",
|
| 171 |
+
"Wales": "gb-wls",
|
| 172 |
+
"Yemen": "ye",
|
| 173 |
+
"Zambia": "zm",
|
| 174 |
+
"Zimbabwe": "zw"
|
| 175 |
+
}
|
app/data/goalscorers.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/data/model/label_encoder.pkl
ADDED
|
Binary file (6.22 kB). View file
|
|
|
app/data/model/linear_regression_team1_goals.pkl
ADDED
|
Binary file (905 Bytes). View file
|
|
|
app/data/model/linear_regression_team2_goals.pkl
ADDED
|
Binary file (905 Bytes). View file
|
|
|
app/data/model/logistic_regression_model.pkl
ADDED
|
Binary file (1.26 kB). View file
|
|
|
app/data/model/train_model.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# model/train_model.py
"""Train and persist the football match prediction models.

Fits a multinomial LogisticRegression over symmetric (alphabetically
ordered) team pairs for the match outcome, plus two LinearRegression
models for each side's goal count, and saves all three together with the
shared team-name LabelEncoder under model/ via joblib.
"""
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.preprocessing import LabelEncoder
import joblib
import os
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the historical results at import time so matches_df is available both
# here and to any importer of this module.
try:
    matches_df = pd.read_csv('data/results.csv')
except FileNotFoundError as e:
    logger.error(f"Dataset not found: {e}")
    raise

# Scores may contain missing/unparseable entries; treat those as 0 goals.
matches_df['home_score'] = pd.to_numeric(matches_df['home_score'], errors='coerce').fillna(0)
matches_df['away_score'] = pd.to_numeric(matches_df['away_score'], errors='coerce').fillna(0)


def train_and_save_models():
    """Train the outcome and goal models on matches_df and save them.

    Artifacts written (joblib):
      model/logistic_regression_model.pkl     - 3-class outcome classifier
      model/linear_regression_team1_goals.pkl - goals of alphabetically-first team
      model/linear_regression_team2_goals.pkl - goals of alphabetically-second team
      model/label_encoder.pkl                 - team-name LabelEncoder
    """
    # --- Prepare Data ---
    # Symmetric outcome label: 0 = alphabetically-first team wins,
    # 1 = draw, 2 = alphabetically-second team wins.  Sorting the pair
    # removes home/away ordering from the label encoding.
    def get_match_outcome(row):
        if row['home_score'] > row['away_score']:
            return 0 if row['home_team'] < row['away_team'] else 2
        elif row['home_score'] < row['away_score']:
            return 2 if row['home_team'] < row['away_team'] else 0
        else:
            return 1

    matches_df['outcome'] = matches_df.apply(get_match_outcome, axis=1)

    # Sort teams alphabetically so (A, B) and (B, A) map to the same pair.
    matches_df['team1'] = matches_df.apply(lambda x: min(x['home_team'], x['away_team']), axis=1)
    matches_df['team2'] = matches_df.apply(lambda x: max(x['home_team'], x['away_team']), axis=1)

    # One encoder fitted on every team seen in either column, shared by all
    # models so encodings stay consistent at prediction time.
    all_teams = pd.concat([matches_df['home_team'], matches_df['away_team']]).unique()
    team_encoder = LabelEncoder()
    team_encoder.fit(all_teams)

    # --- Logistic Regression for Match Outcome ---
    X_outcome = pd.DataFrame({
        'team1': team_encoder.transform(matches_df['team1']),
        'team2': team_encoder.transform(matches_df['team2'])
    })
    y_outcome = matches_df['outcome']

    # Hold out 20% (the held-out split is currently unused; kept for future
    # evaluation and to mirror the goal-model split).
    X_train_outcome, _, y_train_outcome, _ = train_test_split(X_outcome, y_outcome, test_size=0.2, random_state=42)
    # NOTE: multi_class='multinomial' was removed - it is deprecated in
    # scikit-learn >= 1.5, and lbfgs already uses the multinomial loss for a
    # 3-class target, so the fitted model is unchanged.
    logistic_model = LogisticRegression(max_iter=1000)
    logistic_model.fit(X_train_outcome, y_train_outcome)

    # --- Linear Regression for Goal Prediction ---
    X_goals = X_outcome.copy()  # same symmetric team-pair features

    # BUGFIX: the original used home_score/away_score directly as the
    # team1/team2 targets, but team1 is the *alphabetically-first* team, not
    # the home team - the goals were mislabelled whenever team1 played away.
    # Align the targets with the team1/team2 encoding:
    team1_is_home = matches_df['team1'] == matches_df['home_team']
    y_team1_goals = matches_df['home_score'].where(team1_is_home, matches_df['away_score'])
    y_team2_goals = matches_df['away_score'].where(team1_is_home, matches_df['home_score'])

    # Split data for goal prediction (same seed keeps rows aligned across targets)
    X_train_goals, _, y_train_team1_goals, _, y_train_team2_goals, _ = train_test_split(
        X_goals, y_team1_goals, y_team2_goals, test_size=0.2, random_state=42
    )

    # Train one regressor per side's goal count
    linear_model_team1 = LinearRegression()
    linear_model_team2 = LinearRegression()
    linear_model_team1.fit(X_train_goals, y_train_team1_goals)
    linear_model_team2.fit(X_train_goals, y_train_team2_goals)

    # Ensure the model directory exists, then persist everything the
    # serving code loads at startup.
    os.makedirs('model', exist_ok=True)
    joblib.dump(logistic_model, 'model/logistic_regression_model.pkl')
    joblib.dump(linear_model_team1, 'model/linear_regression_team1_goals.pkl')
    joblib.dump(linear_model_team2, 'model/linear_regression_team2_goals.pkl')
    joblib.dump(team_encoder, 'model/label_encoder.pkl')

    logger.info("Logistic Regression and Linear Regression models, along with LabelEncoder, saved successfully.")


if __name__ == "__main__":
    train_and_save_models()
app/data/results.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/main.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
import uvicorn
import plotly.graph_objects as go
import logging
import joblib
import numpy as np
import os
import json
from groq import Groq
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # RuntimeError (not HTTPException): this runs at import time, before any
    # request exists, so FastAPI's exception handling is not involved.
    raise RuntimeError("GROQ_API_KEY not found in environment variables.")

client = Groq(api_key=GROQ_API_KEY)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS to allow frontend communication
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust to specific frontend URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load datasets and country codes once at startup.
try:
    matches_df = pd.read_csv('data/results.csv')
    goals_df = pd.read_csv('data/goalscorers.csv')
    with open('data/country_codes.json', 'r') as f:
        COUNTRY_CODE_MAP = json.load(f)
except FileNotFoundError as e:
    logger.error(f"File not found: {e}")
    # BUGFIX: was `raise HTTPException(...)` at module scope - that would
    # never be converted into an HTTP response (no request is in flight);
    # fail the import fast with a plain exception instead.
    raise RuntimeError("Data files not found or inaccessible") from e
except pd.errors.EmptyDataError as e:
    logger.error(f"CSV files are empty: {e}")
    raise RuntimeError("Data files are empty or invalid") from e

# Coerce scores to numeric; unparseable/missing entries count as 0 goals.
matches_df['home_score'] = pd.to_numeric(matches_df['home_score'], errors='coerce').fillna(0)
matches_df['away_score'] = pd.to_numeric(matches_df['away_score'], errors='coerce').fillna(0)

# The dataset carries no shot coordinates, so synthesise reproducible ones
# (fixed seed): x in [80, 100] when the scorer's team is the home team,
# otherwise x in [0, 20]; y anywhere in [20, 80].
np.random.seed(42)
goals_df['x_coord'] = np.where(
    goals_df['team'] == goals_df['home_team'],
    np.random.uniform(80, 100, len(goals_df)).round(),
    np.random.uniform(0, 20, len(goals_df)).round()
)
goals_df['y_coord'] = np.random.uniform(20, 80, len(goals_df)).round()

# Lookup structures shared by the endpoints below.
teams = set(matches_df['home_team'].unique()).union(set(matches_df['away_team'].unique()))
# dropna() already removes missing scorers, so the former pd.notna() filter
# was redundant and has been dropped.
players = sorted(str(scorer) for scorer in goals_df['scorer'].dropna().unique())

# Load the trained models produced by data/model/train_model.py.
try:
    logistic_model = joblib.load('model/logistic_regression_model.pkl')
    linear_model_team1 = joblib.load('model/linear_regression_team1_goals.pkl')
    linear_model_team2 = joblib.load('model/linear_regression_team2_goals.pkl')
    le = joblib.load('model/label_encoder.pkl')
    logger.info("Models loaded successfully.")
except FileNotFoundError as e:
    logger.error(f"Model files not found: {e}")
    # Same reasoning as above: import-time failure, not a request failure.
    raise RuntimeError("Trained model files not found.") from e
def summarize_with_groq(text):
    """Return a concise Groq-generated summary of *text*.

    Never propagates API failures: any exception is logged and a fixed
    placeholder string is returned instead.
    """
    try:
        prompt_messages = [
            {"role": "system", "content": "You are a helpful assistant that provides concise summaries."},
            {"role": "user", "content": f"Summarize the following text:\n\n{text}"},
        ]
        completion = client.chat.completions.create(
            messages=prompt_messages,
            model="llama-3.3-70b-versatile",
            max_tokens=150
        )
        return completion.choices[0].message.content
    except Exception as e:
        logger.error(f"Error summarizing with Groq: {e}")
        return "Summary unavailable due to an error."
| 86 |
+
|
def get_team_stats(team_name):
    """Aggregate the overall and per-tournament record for one team.

    Returns a dict with total wins/losses/draws, home/away match counts,
    a per-tournament breakdown, and the team's flag country code
    ("unknown" when the team is missing from COUNTRY_CODE_MAP).
    """
    as_home = matches_df[matches_df['home_team'] == team_name]
    as_away = matches_df[matches_df['away_team'] == team_name]

    # Team never appears in the dataset: return an all-zero skeleton so
    # callers always get a uniform shape.
    if as_home.empty and as_away.empty:
        return {
            "total_matches": 0,
            "wins": 0,
            "losses": 0,
            "draws": 0,
            "home_matches_played": 0,
            "away_matches_played": 0,
            "tournament_performance": {},
            "country_code": COUNTRY_CODE_MAP.get(team_name, "unknown")
        }

    # Overall record: outscoring the opponent from either venue is a win.
    wins = int((as_home['home_score'] > as_home['away_score']).sum()
               + (as_away['away_score'] > as_away['home_score']).sum())
    losses = int((as_home['home_score'] < as_home['away_score']).sum()
                 + (as_away['away_score'] < as_away['home_score']).sum())
    draws = int((as_home['home_score'] == as_home['away_score']).sum()
                + (as_away['away_score'] == as_away['home_score']).sum())

    combined = pd.concat([as_home, as_away])
    tournament_stats = {}

    # Per-tournament breakdown over every competition the team appeared in.
    for tournament in combined['tournament'].unique():
        in_tourn = combined[combined['tournament'] == tournament]
        won_mask = (((in_tourn['home_team'] == team_name) & (in_tourn['home_score'] > in_tourn['away_score'])) |
                    ((in_tourn['away_team'] == team_name) & (in_tourn['away_score'] > in_tourn['home_score'])))
        lost_mask = (((in_tourn['home_team'] == team_name) & (in_tourn['home_score'] < in_tourn['away_score'])) |
                     ((in_tourn['away_team'] == team_name) & (in_tourn['away_score'] < in_tourn['home_score'])))
        t_wins = int(won_mask.sum())
        t_losses = int(lost_mask.sum())
        t_draws = int((in_tourn['home_score'] == in_tourn['away_score']).sum())
        t_total = t_wins + t_losses + t_draws
        tournament_stats[tournament] = {
            "matches_played": t_total,
            "wins": t_wins,
            "losses": t_losses,
            "draws": t_draws,
            "win_percentage": round((t_wins / t_total * 100), 2) if t_total > 0 else 0.0
        }

    return {
        "total_matches": len(as_home) + len(as_away),
        "wins": wins,
        "losses": losses,
        "draws": draws,
        "home_matches_played": len(as_home),
        "away_matches_played": len(as_away),
        "tournament_performance": tournament_stats,
        "country_code": COUNTRY_CODE_MAP.get(team_name, "unknown")
    }
| 144 |
+
|
def get_match_goalscorers(date, home_team, away_team):
    """List scorer records for one match, identified by date + both team names.

    Each record carries scorer, minute, team, own_goal and penalty fields.
    """
    same_match = ((goals_df['date'] == date)
                  & (goals_df['home_team'] == home_team)
                  & (goals_df['away_team'] == away_team))
    wanted_columns = ['scorer', 'minute', 'team', 'own_goal', 'penalty']
    return goals_df.loc[same_match, wanted_columns].to_dict('records')
| 150 |
+
|
def get_head_to_head_stats(team1, team2, num_matches=5):
    """Compute the head-to-head record between two teams.

    Returns totals, per-team wins and goals, a readable goal-difference
    string, details (incl. goalscorers) for the last *num_matches* meetings,
    and a stacked-bar Plotly chart serialised to JSON.  Keys embedding a
    team name (e.g. "<team1>_wins") are built from the arguments.
    """
    either_order = (((matches_df['home_team'] == team1) & (matches_df['away_team'] == team2)) |
                    ((matches_df['home_team'] == team2) & (matches_df['away_team'] == team1)))
    meetings = matches_df[either_order]

    # No previous meetings: zeroed summary, no chart.
    if meetings.empty:
        return {"total_matches": 0, f"{team1}_wins": 0, f"{team2}_wins": 0, "draws": 0,
                f"{team1}_goals": 0, f"{team2}_goals": 0, "goal_difference": "Even",
                "last_matches": [], "chart": None}

    def wins_for(team):
        # A win from either venue counts.
        return len(meetings[((meetings['home_team'] == team) & (meetings['home_score'] > meetings['away_score'])) |
                            ((meetings['away_team'] == team) & (meetings['away_score'] > meetings['home_score']))])

    def goals_for(team):
        # Goals scored as home side plus goals scored as away side.
        return (meetings[meetings['home_team'] == team]['home_score'].sum() +
                meetings[meetings['away_team'] == team]['away_score'].sum())

    total_matches = len(meetings)
    team1_wins = wins_for(team1)
    team2_wins = wins_for(team2)
    draws = len(meetings[meetings['home_score'] == meetings['away_score']])
    team1_goals = goals_for(team1)
    team2_goals = goals_for(team2)

    goal_diff = team1_goals - team2_goals
    if goal_diff > 0:
        goal_difference_str = f"{team1} +{int(goal_diff)}"
    elif goal_diff < 0:
        goal_difference_str = f"{team2} +{int(abs(goal_diff))}"
    else:
        goal_difference_str = "Even"

    # Most recent meetings (dataset order) with their goalscorers attached.
    last_n_results = []
    for _, match in meetings.tail(num_matches).iterrows():
        last_n_results.append({
            "date": match['date'], "home_team": match['home_team'], "away_team": match['away_team'],
            "home_score": int(match['home_score']), "away_score": int(match['away_score']),
            "tournament": match['tournament'],
            "goalscorers": get_match_goalscorers(match['date'], match['home_team'], match['away_team'])
        })

    # Normalise wins/goals/goal-difference into [0, 1] proportions for the chart.
    total_wins = team1_wins + team2_wins
    win_prop_team1 = team1_wins / total_wins if total_wins > 0 else 0
    win_prop_team2 = team2_wins / total_wins if total_wins > 0 else 0
    total_goals = team1_goals + team2_goals
    goal_prop_team1 = team1_goals / total_goals if total_goals > 0 else 0
    goal_prop_team2 = team2_goals / total_goals if total_goals > 0 else 0
    goal_diff_value = int(abs(goal_diff))
    goal_diff_prop_team1 = goal_diff_value / (goal_diff_value + 1) if goal_diff_value > 0 else 0.5
    goal_diff_prop_team2 = 1 - goal_diff_prop_team1 if goal_diff_value > 0 else 0.5

    fig = go.Figure(data=[
        go.Bar(name=team1, x=[win_prop_team1, goal_prop_team1, goal_diff_prop_team1], y=['Wins', 'Goals', 'Goal Difference'], orientation='h', marker_color='teal'),
        go.Bar(name=team2, x=[win_prop_team2, goal_prop_team2, goal_diff_prop_team2], y=['Wins', 'Goals', 'Goal Difference'], orientation='h', marker_color='orange')
    ])
    fig.update_layout(barmode='stack', title_text=f'Proportion of {team1} vs {team2}', xaxis_title="Proportion", yaxis_title="Categories", xaxis=dict(range=[0, 1]))

    return {
        "total_matches": total_matches, f"{team1}_wins": team1_wins, f"{team2}_wins": team2_wins, "draws": draws,
        f"{team1}_goals": int(team1_goals), f"{team2}_goals": int(team2_goals), "goal_difference": goal_difference_str,
        "last_matches": last_n_results, "chart": fig.to_json()
    }
| 206 |
+
|
def get_player_stats(player_name):
    """Return {player_name, country, total_goals} for a goalscorer.

    country is the team the player most often scored for; total_goals
    excludes own goals.  Raises HTTPException(404) when the player never
    appears in goals_df.
    """
    player_goals = goals_df[goals_df['scorer'] == player_name]
    if player_goals.empty:
        raise HTTPException(status_code=404, detail="Player not found")
    # Own goals are recorded against the scorer but do not count here.
    total_goals = int((player_goals['own_goal'] == False).sum())  # noqa: E712 - elementwise pandas comparison
    # BUGFIX: the old guard checked player_goals['team'].empty, which is
    # never True once player_goals is non-empty - yet mode() returns an
    # EMPTY series when every 'team' entry is NaN, so mode()[0] could still
    # raise.  Guard the mode result itself instead.
    team_modes = player_goals['team'].mode()
    player_team = team_modes.iloc[0] if not team_modes.empty else "Unknown"
    return {"player_name": player_name, "country": player_team, "total_goals": total_goals}
| 214 |
+
|
def predict_match_outcome(team1, team2):
    """Predict outcome probabilities and a heuristic scoreline for a fixture.

    The classifier was trained on alphabetically-sorted team pairs with
    labels 0 = first team wins, 1 = draw, 2 = second team wins, so the pair
    is sorted before encoding and the probability columns are swapped back
    when team1 is not the alphabetically-first side.

    Raises HTTPException(400) for teams absent from the training data and
    HTTPException(500) for unexpected model errors.
    """
    try:
        teams_sorted = sorted([team1, team2])
        # BUGFIX: an unknown team used to raise ValueError, which the broad
        # handler below converted into a 500.  It is a client error -> 400,
        # and HTTPExceptions are re-raised untouched.
        known = set(le.classes_)
        if teams_sorted[0] not in known or teams_sorted[1] not in known:
            raise HTTPException(status_code=400, detail="One or both teams not found in training data")
        first_encoded = le.transform([teams_sorted[0]])[0]
        second_encoded = le.transform([teams_sorted[1]])[0]

        probs = logistic_model.predict_proba([[first_encoded, second_encoded]])[0]

        # probs is ordered for the sorted pair: [first wins, draw, second wins].
        if team1 < team2:
            outcome_probs = {
                "team1_win": round(probs[0] * 100, 2),
                "team2_win": round(probs[2] * 100, 2),
                "draw": round(probs[1] * 100, 2)
            }
        else:
            outcome_probs = {
                "team1_win": round(probs[2] * 100, 2),
                "team2_win": round(probs[0] * 100, 2),
                "draw": round(probs[1] * 100, 2)
            }

        # Heuristic scoreline from the most likely outcome; the trained
        # goal-count regressors are not consulted here.
        if outcome_probs["team1_win"] > outcome_probs["team2_win"] and outcome_probs["team1_win"] >= outcome_probs["draw"]:
            goals_pred = {"team1_goals": 2, "team2_goals": 1}
        elif outcome_probs["team2_win"] > outcome_probs["team1_win"] and outcome_probs["team2_win"] >= outcome_probs["draw"]:
            goals_pred = {"team1_goals": 1, "team2_goals": 2}
        else:
            goals_pred = {"team1_goals": 1, "team2_goals": 1}

        return {
            "outcome_probabilities": outcome_probs,
            "predicted_goals": goals_pred
        }
    except HTTPException:
        raise  # keep deliberate HTTP errors (e.g. the 400 above) intact
    except Exception as e:
        logger.error(f"Prediction error for {team1} vs {team2}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
| 255 |
+
|
@app.get("/")
async def home():
    """API landing page: a short description plus an endpoint directory."""
    endpoint_directory = {
        "/teams": "List all teams",
        "/players": "List all players",
        "/country-codes": "Get country codes",
        "/team/{team_name}": "Get team statistics",
        "/head-to-head/{team1}/{team2}": "Get head-to-head statistics",
        "/player/{player_name}": "Get player statistics",
        "/predict/{team1}/{team2}": "Predict match outcome",
        "/goal-spatial-heatmap/{team}": "Get goal distribution heatmap"
    }
    return {
        "message": "Welcome to Football Prediction API",
        "description": "This API provides football statistics, match predictions, and data visualizations",
        "available_endpoints": endpoint_directory
    }
| 272 |
+
|
@app.get("/teams")
async def get_teams():
    """All team names from the results dataset, alphabetically sorted."""
    return {"teams": sorted(teams)}

@app.get("/players")
async def get_players():
    """All known goalscorer names (pre-sorted at startup)."""
    return {"players": players}

@app.get("/country-codes")
async def get_country_codes():
    """Mapping of team name -> flag country code."""
    return COUNTRY_CODE_MAP
| 284 |
+
|
@app.get("/team/{team_name}")
async def get_team_statistics(team_name: str, summarize: bool = False):
    """Team statistics endpoint; pass summarize=true for a Groq summary.

    404 for unknown teams, 500 if the stats computation itself fails.
    """
    if team_name not in teams:
        raise HTTPException(status_code=404, detail=f"Team {team_name} not found")
    try:
        stats = get_team_stats(team_name)
    except Exception as e:
        logger.error(f"Error calculating stats for {team_name}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error calculating stats: {str(e)}")

    response = {"team": team_name, "statistics": stats}

    if summarize:
        # Flatten the stats into plain text for the LLM: scalar fields
        # first, then one line per tournament.
        basic_stats_text = "\n".join(f"{key}: {value}" for key, value in stats.items() if key != "tournament_performance")
        tournament_lines = []
        for tourn, perf in stats['tournament_performance'].items():
            tournament_lines.append(
                f"{tourn}: Matches: {perf['matches_played']}, "
                f"Wins: {perf['wins']}, "
                f"Losses: {perf['losses']}, "
                f"Draws: {perf['draws']}, "
                f"Win%: {perf['win_percentage']}%"
            )
        tournament_text = "\nTournament Performance:\n" + "\n".join(tournament_lines)
        response["summary"] = summarize_with_groq(f"{basic_stats_text}\n{tournament_text}")

    return response
| 312 |
+
|
| 313 |
+
@app.get("/head-to-head/{team1}/{team2}")
async def get_head_to_head(team1: str, team2: str, num_matches: int = 5, summarize: bool = False):
    """Return head-to-head statistics between two teams.

    Args:
        team1: First team name; must exist in the ``teams`` set.
        team2: Second team name; must exist in the ``teams`` set.
        num_matches: How many most-recent meetings to include (must be >= 0).
        summarize: When true, attach an LLM-generated text summary.

    Raises:
        HTTPException: 404 for unknown team(s), 400 for a negative
        ``num_matches``, 500 if stats computation fails.
    """
    if team1 not in teams or team2 not in teams:
        raise HTTPException(status_code=404, detail="One or both teams not found")
    if num_matches < 0:
        raise HTTPException(status_code=400, detail="Number of matches must be non-negative")
    # Guard the stats call like the /team endpoint does, so a failure surfaces
    # as a descriptive 500 instead of an opaque server error (consistency).
    try:
        stats = get_head_to_head_stats(team1, team2, num_matches)
    except Exception as e:
        logger.error(f"Error calculating head-to-head stats for {team1} vs {team2}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error calculating stats: {str(e)}")

    response = {"team1": team1, "team2": team2, "head_to_head_statistics": stats}

    if summarize:
        scalar_lines = [f"{key}: {value}" for key, value in stats.items()
                        if key not in ["last_matches", "chart"]]
        match_lines = [
            f"Last Match: {match['date']} - {match['home_team']} "
            f"{match['home_score']} vs {match['away_score']} {match['away_team']}"
            for match in stats["last_matches"]
        ]
        response["summary"] = summarize_with_groq("\n".join(scalar_lines + match_lines))

    return response
|
| 328 |
+
|
| 329 |
+
@app.get("/player/{player_name}")
async def get_player_statistics(player_name: str, summarize: bool = False):
    """Return career statistics for one player.

    Args:
        player_name: Player name as it appears in the goalscorers data.
        summarize: When true, attach an LLM-generated text summary.
    """
    stats = get_player_stats(player_name)
    # Fix: the original did `response = stats` and then mutated the dict when
    # adding "summary", which also mutated whatever get_player_stats returned
    # (dangerous if that result is ever cached or shared). Copy first.
    response = dict(stats)
    if summarize:
        text = "\n".join(f"{key}: {value}" for key, value in stats.items())
        response["summary"] = summarize_with_groq(text)
    return response
|
| 338 |
+
|
| 339 |
+
@app.get("/predict/{team1}/{team2}")
async def predict_match(team1: str, team2: str, summarize: bool = False):
    """Predict the outcome probabilities and score line for a hypothetical match."""
    if team1 not in teams or team2 not in teams:
        raise HTTPException(status_code=404, detail="One or both teams not found")

    predictions = predict_match_outcome(team1, team2)
    response = {"team1": team1, "team2": team2, "predictions": predictions}

    if summarize:
        # Name the nested dicts once instead of repeating the full path.
        probs = predictions["outcome_probabilities"]
        goals = predictions["predicted_goals"]
        text = (
            f"Outcome Probabilities: {team1} Win: {probs['team1_win']}%, "
            f"{team2} Win: {probs['team2_win']}%, Draw: {probs['draw']}%\n"
            f"Predicted Goals: {team1}: {goals['team1_goals']}, {team2}: {goals['team2_goals']}"
        )
        response["summary"] = summarize_with_groq(text)

    return response
|
| 352 |
+
|
| 353 |
+
@app.get("/goal-spatial-heatmap/{team}")
async def get_goal_spatial_heatmap(team: str, start_year: int = 2000, end_year: int = 2023, summarize: bool = False):
    """Build a Plotly heatmap of where a team's goals were scored on the pitch.

    Coordinates are assumed to be on a normalized 100x100 pitch — TODO confirm
    against the goalscorers CSV.

    Args:
        team: Team name; must exist in the ``teams`` set.
        start_year: Inclusive lower bound of the year filter.
        end_year: Inclusive upper bound of the year filter.
        summarize: When true, attach an LLM-generated text summary.

    Returns:
        dict with the Plotly figure JSON, total goal count and goals per match.

    Raises:
        HTTPException: 404 for an unknown team or no goal data in range,
        400 for an inverted year range, 500 for unexpected failures.
    """
    if team not in teams:
        raise HTTPException(status_code=404, detail=f"Team {team} not found")

    if start_year > end_year:
        raise HTTPException(status_code=400, detail="start_year must be less than or equal to end_year")

    try:
        # Parse dates into local Series instead of writing the parsed columns
        # back into the shared module-level DataFrames on every request
        # (the original mutated matches_df/goals_df as a side effect).
        match_dates = pd.to_datetime(matches_df['date'])
        goal_dates = pd.to_datetime(goals_df['date'])

        team_matches = matches_df[
            ((matches_df['home_team'] == team) | (matches_df['away_team'] == team)) &
            (match_dates.dt.year >= start_year) & (match_dates.dt.year <= end_year)
        ]

        team_goals = goals_df[
            (goals_df['team'] == team) &
            (goal_dates.dt.year >= start_year) & (goal_dates.dt.year <= end_year)
        ].dropna(subset=['x_coord', 'y_coord'])

        if team_goals.empty:
            raise HTTPException(status_code=404, detail=f"No goal data found for {team} in the specified year range")

        # 50x50 density grid over the 100x100 pitch.
        heatmap_data, xedges, yedges = np.histogram2d(
            team_goals['x_coord'],
            team_goals['y_coord'],
            bins=50,
            range=[[0, 100], [0, 100]]
        )

        # Normalize to [0, 1] so the color scale is comparable across teams.
        if heatmap_data.max() > 0:
            heatmap_data = heatmap_data / heatmap_data.max()

        fig = go.Figure(data=go.Heatmap(
            z=heatmap_data.T,  # transpose: histogram2d returns x along axis 0
            x=xedges,
            y=yedges,
            colorscale='Viridis',
            colorbar=dict(title='Goal Density'),
            zmin=0,
            zmax=1
        ))

        # Pitch markings: outline, penalty areas, goal areas, centre circle, halfway line.
        fig.add_shape(type="rect", x0=0, y0=0, x1=100, y1=100, line=dict(color="white", width=2))
        fig.add_shape(type="rect", x0=0, y0=20, x1=16, y1=80, line=dict(color="white", width=2))
        fig.add_shape(type="rect", x0=84, y0=20, x1=100, y1=80, line=dict(color="white", width=2))
        fig.add_shape(type="rect", x0=0, y0=40, x1=5, y1=60, line=dict(color="white", width=2))
        fig.add_shape(type="rect", x0=95, y0=40, x1=100, y1=60, line=dict(color="white", width=2))
        fig.add_shape(type="circle", x0=45, y0=45, x1=55, y1=55, line=dict(color="white", width=2))
        fig.add_shape(type="line", x0=50, y0=0, x1=50, y1=100, line=dict(color="white", width=2))

        fig.update_layout(
            title=f'Goal Distribution Heatmap for {team} ({start_year}-{end_year})',
            xaxis_title='X Position (Length of Pitch)',
            yaxis_title='Y Position (Width of Pitch)',
            xaxis=dict(range=[0, 100], tickvals=[0, 20, 40, 60, 80, 100], showgrid=False),
            yaxis=dict(range=[0, 100], tickvals=[0, 20, 40, 60, 80, 100], showgrid=False),
            template="plotly_dark",
            width=800,
            height=500,
            plot_bgcolor='rgba(0,128,0,0.3)',
            paper_bgcolor='rgba(0,0,0,0)'
        )

        response = {
            "team": team,
            "start_year": start_year,
            "end_year": end_year,
            "heatmap": fig.to_json(),
            "total_goals": len(team_goals),
            "average_goals_per_match": round(len(team_goals) / len(team_matches) if len(team_matches) > 0 else 0, 2)
        }

        if summarize:
            text = (f"Goal Distribution for {team} ({start_year}-{end_year})\n"
                    f"Total Goals: {len(team_goals)}\n"
                    f"Average Goals per Match: {response['average_goals_per_match']:.2f}")
            response["summary"] = summarize_with_groq(text)

        return response

    except HTTPException:
        # Bug fix: the broad handler below previously caught the deliberate
        # 404 "no goal data" raise above and converted it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error generating spatial heatmap for {team}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating heatmap: {str(e)}")
|
| 439 |
+
|
| 440 |
+
if __name__ == "__main__":
    # Local development entry point; in production the Dockerfile's CMD
    # launches uvicorn directly with the same host/port.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.103.1
|
| 2 |
+
uvicorn==0.23.2
|
| 3 |
+
pandas==2.0.3
|
| 4 |
+
plotly==5.15.0
|
| 5 |
+
joblib==1.3.2
|
| 6 |
+
numpy==1.25.2
|
| 7 |
+
groq==0.9.0
|
| 8 |
+
python-dotenv==1.0.0
|