File size: 5,161 Bytes
ba54b37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
from fastapi import FastAPI, UploadFile, File, Depends, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from db import Base, engine, get_db
from models import Product, CustomerProfile
from schemas import (
ProductOut, CustomerOut, CustomerUpdate,
ChatRequest, ChatResponse, ConversationOut
)
from services import extract_and_upsert_products_from_llm, ensure_default_customers, get_or_create_conversation, add_message, get_history
from typing import List
import tempfile
# Application object; the OpenAPI schema is served at /openapi.json.
app = FastAPI(openapi_url="/openapi.json", title="engine-ddo")

# CORS for the Streamlit UI Space: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True means
# Starlette answers credentialed requests by reflecting the caller's Origin
# instead of sending "*" — confirm that is intended, or pin the UI's origin.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Initialize DB: create any tables registered on Base.metadata at import time.
Base.metadata.create_all(bind=engine)
@app.get("/health")
def health():
    """Report that the service is up; the body is always ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
# -------- PRODUCTS --------
@app.post("/products/ingest", response_model=List[ProductOut])
async def ingest_products(public_offering: UploadFile = File(...), private_notes: UploadFile = File(...), db: Session = Depends(get_db)):
    """Ingest the public-offering and private-notes PDFs, then list all products.

    Each upload is spooled to a temporary ``.pdf`` file because
    ``extract_and_upsert_products_from_llm`` takes file paths. Every temp file
    that was created is removed afterwards, including when reading/writing the
    second upload or the extraction itself fails.

    Returns:
        All products in the database, ordered by name ascending.
    """
    temp_paths = []
    try:
        # Same order as before: public offering first, then private notes.
        for upload in (public_offering, private_notes):
            data = await upload.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                # Record the path before writing so a failed write is still cleaned up.
                temp_paths.append(tmp.name)
                tmp.write(data)
        public_path, notes_path = temp_paths
        # Call the service to process the PDFs with an LLM.
        # NOTE(review): this is a synchronous call inside an async endpoint, so it
        # blocks the event loop while extraction runs.
        extract_and_upsert_products_from_llm(db, public_path, notes_path)
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: keep removing the remaining files and do not
                # let a cleanup failure mask the primary outcome.
                pass
    # Return all products from the database
    return db.query(Product).order_by(Product.name.asc()).all()
@app.get("/products/list", response_model=List[ProductOut])
def list_products(db: Session = Depends(get_db)):
    """Return every stored product, sorted by name (ascending)."""
    by_name = Product.name.asc()
    return db.query(Product).order_by(by_name).all()
# -------- CUSTOMERS --------
@app.get("/customers/list", response_model=List[CustomerOut])
def list_customers(db: Session = Depends(get_db)):
    """Return all customer profiles, sorted by name (ascending).

    ``ensure_default_customers`` runs first on every call; presumably it seeds
    the default profiles when missing — confirm in ``services``.
    """
    ensure_default_customers(db)
    by_name = CustomerProfile.name.asc()
    return db.query(CustomerProfile).order_by(by_name).all()
@app.post("/customers/update", response_model=CustomerOut)
def update_customer(payload: CustomerUpdate, db: Session = Depends(get_db)):
    """Create or update the customer profile named ``payload.name``.

    A new profile is added when no row with that name exists. Only payload
    fields that are not None (``attributes``, ``wcltv``, ``n``) are written;
    the others keep their stored value. Returns the committed, refreshed row.
    """
    profile = db.query(CustomerProfile).filter_by(name=payload.name).first()
    if not profile:
        profile = CustomerProfile(name=payload.name)
        db.add(profile)
    # Partial update: None means "leave unchanged".
    for field in ("attributes", "wcltv", "n"):
        value = getattr(payload, field)
        if value is not None:
            setattr(profile, field, value)
    db.commit()
    db.refresh(profile)
    return profile
# -------- INTERACTIONS (chat) --------
# Read once at import time; when the key is unset, llm_reply uses the demo stub.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")


async def llm_reply(system_prompt: str, history: list, user_text: str) -> str:
    """Return a reply from an external LLM if OPENAI_API_KEY set, else a rule-based stub.

    Args:
        system_prompt: Sent as the leading ``"system"`` message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts, placed
            between the system prompt and *user_text*.
        user_text: The newest user message; appended as the final ``"user"`` turn.

    Returns:
        The model's reply with surrounding whitespace stripped. Any error while
        talking to the model is not raised; it is turned into an
        ``"[LLM error fallback] ..."`` reply instead.
    """
    if OPENAI_API_KEY:
        try:
            # Use the async client: the sync client would block the event loop
            # (and every other request) for the whole model round-trip.
            from openai import AsyncOpenAI

            messages = [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": user_text}]
            # The context manager closes the client's HTTP connection pool.
            async with AsyncOpenAI(api_key=OPENAI_API_KEY) as client:
                resp = await client.chat.completions.create(model=MODEL, messages=messages)
            return resp.choices[0].message.content.strip()
        except Exception as e:
            # Deliberate best-effort: the chat keeps going if the model is unreachable.
            return f"[LLM error fallback] I couldn't reach the model ({e}). Let's continue anyway."
    # Fallback deterministic reply for demo
    return "Thanks for the details! Could you share your main need, budget, and timeline? I can match a product for you."
@app.post("/interactions/chat", response_model=ChatResponse)
async def chat(req: ChatRequest, db: Session = Depends(get_db)):
    """Store a customer message, generate the agent reply, and store that too.

    The conversation is looked up (or created) for ``req.profile_name``, falling
    back to ``"random"`` when it is empty/None.

    Returns:
        ``{"reply": <agent text>, "conversation_id": <conversation id>}``.
    """
    profile = req.profile_name or "random"
    convo = get_or_create_conversation(db, profile)

    # Build history for LLM from the turns stored *before* this message.
    # llm_reply appends req.user_text as the final user turn itself; reading
    # history after add_message (presumably get_history returns every stored
    # turn) would send the newest message to the model twice.
    hist = [
        {
            "role": "user" if turn["sender"] == "customer" else "assistant",
            "content": turn["text"],
        }
        for turn in get_history(db, convo.id)
    ]
    add_message(db, convo.id, sender="customer", text=req.user_text)

    system_prompt = (
        "You are a helpful sales assistant. Keep answers short, ask clarifying questions, and reference products generically."
    )
    reply = await llm_reply(system_prompt, hist, req.user_text)
    add_message(db, convo.id, sender="agent", text=reply)
    return {"reply": reply, "conversation_id": convo.id}
@app.get("/interactions/history", response_model=ConversationOut)
async def history(profile_name: str, db: Session = Depends(get_db)):
    """Return the stored turns of *profile_name*'s conversation.

    NOTE(review): this GET goes through get_or_create_conversation, so it may
    create a conversation when none exists yet — confirm that is intended.
    """
    conversation = get_or_create_conversation(db, profile_name)
    turns = get_history(db, conversation.id)
    response = dict(id=conversation.id, profile_name=profile_name, history=turns)
    return response
|