File size: 3,453 Bytes
7cc0170
 
 
 
 
 
 
3e8222a
7cc0170
 
 
 
 
 
 
 
 
3e8222a
 
 
 
 
09effdc
 
 
 
 
 
 
3e8222a
 
 
 
7cc0170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from fastapi import FastAPI, Form
from twilio.rest import Client
from dotenv import load_dotenv
import os
from matching import match_technician, NLPProcessor
from models import Database
from transformers import pipeline
from contextlib import asynccontextmanager

# Load .env before constructing anything else, so every initializer below
# (including NLPProcessor, which may read configuration from the
# environment — NOTE(review): confirm) sees the same settings.
load_dotenv()

nlp = NLPProcessor()

app = FastAPI()

# Hugging Face Inference API endpoint for the service-type classifier.
HF_API_URL = "https://api-inference.huggingface.co/models/Christy123/service-classifier"
# API token used by classify_text; None when HF_API_TOKEN is not set.
HF_TOKEN = os.getenv("HF_API_TOKEN")

# Module-level Database handle, assigned and closed by the lifespan context.
db = None

@asynccontextmanager
async def lifespan(app):
    """Hold a module-level Database connection for the lifetime of the app.

    This function is installed as ``app.router.lifespan_context``. Once a
    custom lifespan is installed, FastAPI no longer runs handlers registered
    with ``@app.on_event("startup")``. The schema check in ``startup_db`` is
    therefore awaited explicitly here; otherwise it would never execute.

    Args:
        app: The application instance (unused).

    Yields:
        Control back to the server while the application is running.
    """
    global db
    try:
        db = Database()
        # Fail fast at startup if the schema has not been applied.
        await startup_db()
        yield
    finally:
        if db is not None:
            # Use Database.close(), like every other call site in this module,
            # instead of reaching into the raw connection.
            db.close()
            db = None


# Install the lifespan defined above on the already-created app.
# NOTE: with a custom lifespan context installed, handlers registered via
# @app.on_event("startup"/"shutdown") are not run by the router.
app.router.lifespan_context = lifespan


# Twilio client: module-level REST client used by send_message. Credentials
# are read from the environment (populated by load_dotenv above); if either
# variable is unset, None is passed to Client — NOTE(review): confirm how the
# Twilio SDK handles that.
twilio_client = Client(os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN"))

def send_message(to_number: str, body: str):
    """Send a WhatsApp message to *to_number* through Twilio.

    Args:
        to_number: Recipient address. Either already prefixed with
            ``whatsapp:`` or a bare phone number with or without a leading
            ``+``. Bare numbers are normalised to ``whatsapp:+<digits>``.
        body: Message text.

    Returns:
        The Twilio message SID, or None if sending failed. Failures are printed
        and swallowed (best effort), so a delivery problem does not crash the
        calling request handler.
    """
    try:
        if not to_number.startswith("whatsapp:"):
            # WhatsApp addresses use E.164 numbers, which keep the leading
            # '+'; normalise "254..." and "+254..." to "whatsapp:+254...".
            to_number = f"whatsapp:+{to_number.lstrip('+')}"

        message = twilio_client.messages.create(
            from_=os.getenv("TWILIO_NUMBER"),
            body=body,
            # Deliver to the caller-supplied recipient, not a fixed number.
            to=to_number,
        )
        return message.sid
    except Exception as e:
        print(f"Twilio error: {e}")
        return None
    
def classify_text(text: str, timeout: float = 10.0):
    """Classify *text* with the hosted Hugging Face service classifier.

    Args:
        text: Free-text request to classify.
        timeout: Seconds to wait for the inference API before giving up.

    Returns:
        The decoded JSON response on HTTP 200. Otherwise returns
        ``{"error": "Model inference failed"}`` for any other status, a
        network failure, a timeout, or an undecodable response body.
    """
    # Standard-library HTTP client; `requests` is not imported in this module.
    import json
    import urllib.request

    headers = {"Content-Type": "application/json"}
    # Only send credentials when a token is configured; otherwise the header
    # would literally be "Bearer None".
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"

    request = urllib.request.Request(
        HF_API_URL,
        data=json.dumps({"inputs": text}).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            if response.status != 200:
                return {"error": "Model inference failed"}
            return json.loads(response.read().decode("utf-8"))
    except (OSError, ValueError):
        # OSError covers URLError/HTTPError (non-2xx) and socket timeouts;
        # ValueError covers JSON and Unicode decoding failures.
        return {"error": "Model inference failed"}

@app.post("/predict")
async def predict(text: str):
    """Classify *text* and return the top predicted service and its score.

    Returns:
        ``{"service": <label>, "confidence": <score>}`` taken from the first
        prediction in the classifier response.

    Raises:
        HTTPException: 502 when inference fails or returns an unexpected
            payload. Previously this surfaced as an unhandled KeyError or
            IndexError (HTTP 500).
    """
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    # classify_text performs blocking network I/O; keep it off the event loop.
    result = await run_in_threadpool(classify_text, text)

    try:
        top = result[0]  # KeyError here when classify_text returned its error dict
        # NOTE(review): text-classification responses may wrap predictions in
        # an extra per-input list ([[{...}, ...]]); accept both shapes.
        if isinstance(top, list):
            top = top[0]
        return {"service": top["label"], "confidence": top["score"]}
    except (KeyError, IndexError, TypeError):
        raise HTTPException(status_code=502, detail="Model inference failed") from None

@app.post("/message")
async def handle_message(From: str = Form(...), Body: str = Form(...)):
    """Webhook for incoming messages: advance the sender's conversation state.

    Args:
        From: Sender address, from the ``From`` form field.
        Body: Message text, from the ``Body`` form field.

    Returns:
        A small status dict: ``awaiting_location``, ``error``, ``success`` or
        ``processed`` depending on which branch handled the message.
    """
    # Per-request connection (shadows the module-level `db`). It is always
    # closed in `finally`, including on the error path and on exceptions.
    db = Database()
    try:
        user_state = db.get_user_state(From)
        choice = Body.strip()

        # Case 1: user is in the menu flow and picked 1/2/3.
        services = {"1": "plumbing", "2": "electrical", "3": "hvac"}
        if (
            user_state
            and user_state["state"] == "awaiting_service"
            and choice in services
        ):
            db.update_user_state(From, "awaiting_location", services[choice])
            send_message(From, "Please share your city or area.")
            return {"status": "awaiting_location"}

        # Case 2: free-text request (e.g. "AC repair in Mombasa").
        if not user_state or user_state["state"] == "idle":
            technician, error = match_technician(Body, From)
            if error:
                send_message(From, error)
                return {"status": "error"}

            response = (
                f"Found {technician['name']} (Rating: {technician['rating']}/5) "
                f"for your {nlp.extract_service(Body)} request. "
                f"Contact: {technician['contact']}. Confirm? (Yes/No)"
            )
            send_message(From, response)
            # NOTE(review): Case 1 passes a service type as the third argument,
            # while this passes the reply text — confirm which is intended.
            db.update_user_state(From, "awaiting_confirmation", response)
            return {"status": "success"}

        # Case 3: confirmation/feedback — currently only acknowledged; no
        # state change or reply is made.
        # NOTE(review): "awaiting_location" (set in Case 1) also lands here,
        # so the user's location reply is never processed.
        return {"status": "processed"}
    finally:
        db.close()



# main.py
@app.on_event("startup")
async def startup_db():
    """Fail fast if the database schema has not been applied.

    Runs a trivial query against the ``users`` table. On failure, prints a hint
    to run ``schema.sql`` and re-raises the original error.

    NOTE(review): FastAPI does not run ``on_event`` handlers when a custom
    lifespan is installed, as this module does via
    ``app.router.lifespan_context``. Make sure the lifespan calls this check.
    """
    import sys

    db = Database()
    try:
        # Simple query to verify tables exist.
        db.cursor.execute("SELECT 1 FROM users LIMIT 1")
    except Exception as e:
        # `psycopg2` is not imported in this module, so `except psycopg2.Error`
        # would itself raise NameError. Catch broadly at this startup
        # boundary, report, and re-raise the original exception unchanged.
        print(
            f"CRITICAL: Database not initialized. Run schema.sql first ({e})",
            file=sys.stderr,
        )
        raise
    finally:
        db.close()