# TurkishTrends / app.py
# UHRCRU's picture
# Update app.py
# f62223a verified
import os
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import torch
# --- Load HF token from environment ---
# The token is passed to both from_pretrained() calls below; fail fast with a
# clear message instead of an authentication error deep inside the download.
token = os.getenv("HF_TOKEN")
if token is None:
    raise ValueError("HF_TOKEN environment variable not set")

# --- Use a better and faster model ---
model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,
    device_map="auto",
    # Half precision only when a GPU is available; fall back to float32 on CPU.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# device_map="auto" already places the weights, so the model must not be moved
# again with model.to(...): relocating a dispatched model is warned against and
# fails when layers are offloaded. Tokenized inputs are sent to the device the
# model reports instead.
device = model.device
# --- Load GA4 CSV data ---
def load_ga4_data(path: str = "synthetic_ga4_data.csv") -> pd.DataFrame:
    """Read the GA4 export CSV into a DataFrame.

    Args:
        path: Location of the CSV file. Defaults to ``synthetic_ga4_data.csv``,
            which is resolved relative to the current working directory.

    Returns:
        The parsed table, exactly as produced by ``pandas.read_csv``.
    """
    return pd.read_csv(path)
# Loaded once at import time; the insight code below aggregates this frame
# (df.groupby("City")["Users"]) for every question instead of re-reading the CSV.
df = load_ga4_data()
def generate_insight(prompt):
    """Answer an analytics question and have the language model rephrase it.

    Args:
        prompt: Free-text question from the UI. Matching is case-insensitive
            and ignores surrounding whitespace.

    Returns:
        A ``(response, status)`` tuple of strings. On success ``response`` is
        the model's rephrasing of the computed answer; on any failure it is an
        error message and ``status`` reports the failure.
    """
    try:
        question = prompt.lower().strip()

        if "most users" in question:
            # Aggregate once and reuse the result for both the label and the value.
            users_by_city = df.groupby("City")["Users"].sum()
            top_city = users_by_city.idxmax()
            users = users_by_city.max()
            # NOTE(review): this branch previously left base_answer unset, so the
            # f-string below raised NameError. The sentence wording here is a
            # reconstruction - confirm against the intended copy.
            base_answer = f"{top_city} has the most users, with {users} users in total."
        else:
            base_answer = "Sorry, I can only currently analyze questions about users and conversion rates."

        input_text = f"Rephrase this like a digital marketing analyst for a business report: {base_answer}"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)

        # Greedy decoding (do_sample=False) ignores temperature, so it is not
        # passed; supplying it only triggers a generation-config warning.
        outputs = model.generate(
            **inputs,
            max_new_tokens=80,
            do_sample=False,
            repetition_penalty=1.1,
        )

        # A causal LM returns prompt + continuation; drop the prompt tokens so
        # the instruction text is not echoed back to the user.
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
        return response, "✅ Insight generated successfully."
    except Exception as e:
        # UI boundary: surface the failure in the status box rather than crash.
        return f"⚠️ Error: {str(e)}", "❌ Insight generation failed."
# --- City Performance Plot ---
def plot_city_performance():
# NOTE(review): these statements build the Gradio widgets and are not part of a
# plot routine. Gradio event listeners such as .click() are normally registered
# inside a `with gr.Blocks()` context; no such header is visible here - confirm
# against the full file.

# Free-text question box; its value is the single input handed to generate_insight.
prompt = gr.Textbox(label="Your Analysis Question", placeholder="Which city has the best conversion rate?")
# Button whose click event triggers the insight generation.
generate = gr.Button("Generate Insight")
# Read-only boxes that receive the two outputs of the click handler, in order.
output = gr.Textbox(label="AI Response", interactive=False)
status = gr.Textbox(label="Status", interactive=False, visible=True)
# On click, call generate_insight with the question and fill [output, status].
generate.click(fn=generate_insight, inputs=prompt, outputs=[output, status])
# Static Markdown footer: a horizontal rule followed by the list of metrics.
gr.Markdown("---\n**Available GA4 Metrics:** Users, Sessions, Transactions, Revenue, Avg Session Duration, etc.")