schat / engine-ddo /services.py
VeuReu's picture
Upload 12 files
ba54b37 verified
import os
import json
from sqlalchemy.orm import Session
from models import Product, CustomerProfile, Conversation, Message
from typing import List
from PyPDF2 import PdfReader
# Product services
def _read_pdf_text(file_path: str) -> str:
try:
reader = PdfReader(file_path)
return "\n".join(page.extract_text() or "" for page in reader.pages)
except Exception:
return ""
def _coerce_price(raw):
    """Return *raw* as a float, or None when it is missing or not numeric.

    Unlike a plain truthiness check, a price of 0 is preserved as 0.0.
    Values that float() rejects (e.g. "$9", "", a list) yield None so that a
    single malformed price does not abort the whole import.
    """
    if raw is None:
        return None
    try:
        return float(raw)
    except (TypeError, ValueError):
        return None


def _parse_llm_products(llm_response) -> List[dict]:
    """Turn the raw LLM reply into a list of ``Product`` keyword-argument dicts.

    The reply may wrap the JSON array in extra prose, so only the span from
    the first '[' to the last ']' is parsed.

    Raises:
        TypeError: if the reply is not a string, or an array element is not a
            JSON object.
        ValueError: if no JSON array is present or the JSON is invalid
            (json.JSONDecodeError is a ValueError subclass).
    """
    if not isinstance(llm_response, str):
        raise TypeError(
            f"expected a text response, got {type(llm_response).__name__}"
        )

    start = llm_response.find("[")
    end = llm_response.rfind("]")
    if start == -1 or end < start:
        raise ValueError("no JSON array found in LLM response")

    rows = []
    for item in json.loads(llm_response[start:end + 1]):
        if not isinstance(item, dict):
            raise TypeError(
                f"expected a JSON object per product, got {type(item).__name__}"
            )
        rows.append(
            {
                # A missing, null or empty name falls back to "general" so
                # the row always has a usable key.
                "name": item.get("product") or "general",
                "description": item.get("description"),
                "notes": item.get("notes"),
                "price": _coerce_price(item.get("price")),
            }
        )
    return rows


def extract_and_upsert_products_from_llm(db: Session, public_pdf_path: str, private_pdf_path: str):
    """Extract product info from two PDFs using an LLM and save it to the DB.

    Args:
        db: SQLAlchemy session used to merge and commit ``Product`` rows.
        public_pdf_path: Path to the public offering PDF.
        private_pdf_path: Path to the private notes PDF.

    Behavior:
        * If neither PDF yields any text, three demo products are merged and
          committed, and the LLM is not called.
        * Otherwise the LLM is asked for a JSON array of products. Every
          element is validated before anything is written, so a malformed
          reply never leaves a partial import behind.
        * If the reply cannot be parsed, a single "Parsing Error" placeholder
          product carrying the raw reply in ``notes`` is merged instead.

    Note:
        The LLM coroutine is driven with ``asyncio.run``, which raises
        ``RuntimeError`` when an event loop is already running in the current
        thread; call this function from synchronous code only.
    """
    public_text = _read_pdf_text(public_pdf_path)
    private_text = _read_pdf_text(private_pdf_path)

    if not public_text and not private_text:
        # Fallback for demo if PDFs are empty or unreadable
        demo_products = [
            Product(name="Demo Basic", description="Standard features for small teams.", notes="High churn risk.", price=9.0),
            Product(name="Demo Pro", description="Advanced features and priority support.", notes="Stable customer base.", price=39.0),
            Product(name="Demo Enterprise", description="Dedicated support and custom integrations.", notes="Potential for expansion.", price=199.0),
        ]
        for p in demo_products:
            db.merge(p)
        db.commit()
        return

    # Imported lazily, as in the original, so the demo path above never needs
    # main.py (presumably this also avoids a circular import — confirm).
    import asyncio
    from main import llm_reply

    system_prompt = """
You are an expert data extractor. Your task is to analyze two documents, a public offering and a private notes document, and extract product information.
Respond with a single JSON array of objects. Each object should represent a product and have the following fields:
- "product": The name of the product.
- "description": The description from the public offering document.
- "notes": Internal notes from the private notes document.
- "price": The price as a numeric value (float), if available.
If you find information that does not belong to a specific product, assign it to a product named "general".
Ensure your output is a valid JSON array.
"""
    user_prompt = f"""
Here is the content from the public offering document:
--- PUBLIC OFFERING ---
{public_text}
Here is the content from the private notes document:
--- PRIVATE NOTES ---
{private_text}
Please extract the product information as a JSON array.
"""

    # llm_reply is a coroutine function; asyncio.run blocks until it finishes.
    llm_response = asyncio.run(llm_reply(system_prompt, [], user_prompt))

    try:
        rows = _parse_llm_products(llm_response)
    except (ValueError, TypeError) as e:
        # Handle cases where LLM output is not as expected: log the error and
        # store a placeholder so the failure is visible in the product list.
        print(f"Error parsing LLM response: {e}")
        placeholder = Product(name="Parsing Error", description="Could not parse data from documents.", notes=str(llm_response))
        db.merge(placeholder)
        db.commit()
        return

    for row in rows:
        # Use merge to insert or update based on the primary key (name)
        db.merge(Product(**row))
    db.commit()
# Customer services
def ensure_default_customers(db: Session):
    """Insert the built-in customer profiles that are not in the DB yet.

    Each default is looked up by name; only missing ones are added. A single
    commit is issued after all defaults have been processed.
    """
    # (name, attributes, wcltv, n) for every built-in profile.
    default_profiles = [
        ("random", "Synthetic profile with randomized traits", 0.0, 0),
        ("SMB buyer", "Budget-conscious, quick decisions", 1200.0, 85),
        ("Enterprise buyer", "Long sales cycle, security-focused", 24000.0, 12),
    ]
    for profile_name, attributes, wcltv, count in default_profiles:
        existing = db.query(CustomerProfile).filter_by(name=profile_name).first()
        if existing:
            continue
        new_profile = CustomerProfile(
            name=profile_name,
            attributes=attributes,
            wcltv=wcltv,
            n=count,
        )
        db.add(new_profile)
    db.commit()
# Chat services
def get_or_create_conversation(db: Session, profile_name: str) -> Conversation:
    """Return the newest conversation for *profile_name*, creating one if needed.

    "Newest" means the row with the highest ``Conversation.id`` for that
    profile. When none exists, a new conversation is added, committed and
    refreshed from the database before being returned.
    """
    by_profile = db.query(Conversation).filter_by(profile_name=profile_name)
    latest = by_profile.order_by(Conversation.id.desc()).first()
    if latest:
        return latest

    created = Conversation(profile_name=profile_name)
    db.add(created)
    db.commit()
    # Reload so database-populated attributes are available to the caller.
    db.refresh(created)
    return created
def add_message(db: Session, conversation_id: int, sender: str, text: str):
    """Persist one message for the given conversation and commit immediately."""
    db.add(
        Message(
            conversation_id=conversation_id,
            sender=sender,
            text=text,
        )
    )
    db.commit()
def get_history(db: Session, conversation_id: int):
    """Return a conversation's messages as ``{"sender", "text"}`` dicts.

    The messages are returned in the order ``Conversation.messages`` yields
    them. An unknown *conversation_id* produces an empty list.
    """
    conversation = db.query(Conversation).filter_by(id=conversation_id).first()
    history = []
    if conversation:
        for message in conversation.messages:
            history.append({"sender": message.sender, "text": message.text})
        return history
    return []