# Techi-Pro2.0 / models.py
# Author: Groo12 — commit af4689d ("Add application file")
import logging
import os
import re
from functools import lru_cache
from typing import Optional

import psycopg2
import requests
from dotenv import load_dotenv
from transformers import pipeline
# Populate os.environ from a .env file so that POSTGRES_URL and HF_API_TOKEN,
# which the classes below read via os.getenv, can be configured there.
load_dotenv()
class Database:
    """Thin wrapper around one psycopg2 connection and its shared cursor.

    Connects using the ``POSTGRES_URL`` environment variable. Instances can
    also be used in a ``with`` statement so the connection is closed on exit.
    """

    def __init__(self):
        """Connect, create ``messages`` if missing, and check ``users`` exists.

        Raises:
            RuntimeError: if the ``users`` table does not exist.
            Exception: any error from connecting or from the startup queries
                is logged and re-raised after the connection is closed.
        """
        self.conn = None
        self.cursor = None
        try:
            self.conn = psycopg2.connect(os.getenv("POSTGRES_URL"))
            self.cursor = self.conn.cursor()
            # Verify tables exist on startup
            self.cursor.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id SERIAL PRIMARY KEY,
                    sender VARCHAR(50),
                    body TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            self.conn.commit()
            self.cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'users'
                )
            """)
            if not self.cursor.fetchone()[0]:
                raise RuntimeError("Database tables not initialized")
        except Exception as e:
            logging.getLogger(__name__).error("Database connection error: %s", e)
            # Don't leak a half-initialised connection when construction fails.
            self.close()
            raise

    def _execute(self, query, params=(), *, fetch=False, commit=False):
        """Run one statement on the shared cursor.

        Args:
            query: SQL text using the driver's placeholder style.
            params: Values bound to the placeholders.
            fetch: When true, return the first result row (``None`` if empty).
            commit: When true, commit after the statement (and after fetching).

        Raises:
            Exception: whatever the driver raises; the transaction is rolled
                back first. In PostgreSQL a failed statement aborts the current
                transaction, so without the rollback every later query on this
                long-lived connection would fail too.
        """
        try:
            self.cursor.execute(query, params)
            row = self.cursor.fetchone() if fetch else None
            if commit:
                self.conn.commit()
            return row
        except Exception:
            self.conn.rollback()
            raise

    def find_technician(self, service_type: str, longitude: float, latitude: float):
        """Return the closest available technician qualified for ``service_type``.

        Technicians are ordered by the PostGIS ``<->`` distance between their
        ``location`` and the point (``longitude``, ``latitude``) in SRID 4326.

        Returns:
            A dict with ``id``, ``name``, ``contact`` and ``rating`` keys, or
            ``None`` when nobody matches.
        """
        query = """
            SELECT id, name, contact, rating
            FROM technicians
            WHERE %s = ANY(qualifications)
            AND availability = 'available'
            ORDER BY location <-> ST_SetSRID(ST_MakePoint(%s, %s), 4326)
            LIMIT 1
        """
        row = self._execute(query, (service_type, longitude, latitude), fetch=True)
        if row is None:
            return None
        return dict(zip(("id", "name", "contact", "rating"), row))

    def get_user_state(self, user_number: str):
        """Return ``{"state", "last_message"}`` for ``user_number``, or ``None``."""
        query = "SELECT state, last_message FROM users WHERE number = %s"
        row = self._execute(query, (user_number,), fetch=True)
        return {"state": row[0], "last_message": row[1]} if row else None

    def update_user_state(self, user_number: str, state: str, last_message: str):
        """Insert or update (upsert) the user's conversation state and commit.

        On a conflict on ``number`` the existing row's ``state`` and
        ``last_message`` are overwritten and ``updated_at`` is refreshed.
        """
        # EXCLUDED refers to the row proposed for insertion, so each value is
        # bound only once.
        query = """
            INSERT INTO users (number, state, last_message)
            VALUES (%s, %s, %s)
            ON CONFLICT (number) DO UPDATE
            SET state = EXCLUDED.state,
                last_message = EXCLUDED.last_message,
                updated_at = CURRENT_TIMESTAMP
        """
        self._execute(query, (user_number, state, last_message), commit=True)

    def save_request(self, user_number: str, technician_id: int, service_type: str):
        """Record a ``pending`` service request, commit it and return its id."""
        query = """
            INSERT INTO requests (user_number, technician_id, service_type, status)
            VALUES (%s, %s, %s, %s)
            RETURNING id
        """
        # The RETURNING row is fetched before the commit (see _execute).
        row = self._execute(
            query,
            (user_number, technician_id, service_type, "pending"),
            fetch=True,
            commit=True,
        )
        return row[0]

    def close(self):
        """Close the cursor, then the connection.

        Attributes that were never set (e.g. when ``__init__`` failed before
        connecting) are skipped, and the connection is closed even if closing
        the cursor raises.
        """
        cursor = getattr(self, "cursor", None)
        conn = getattr(self, "conn", None)
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if conn is not None:
                conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
class NLPProcessor:
    """Extract a service category and a location from free-text messages.

    Each extractor first queries a Hugging Face Inference API model and, when
    that is unavailable or not confident enough, falls back to local
    heuristics, so a result is produced even without network access.
    """

    # Seconds to wait for a Hugging Face API response.
    REQUEST_TIMEOUT = 5
    # Minimum classifier score accepted from the remote service model.
    CONFIDENCE_THRESHOLD = 0.7
    # Place names recognised by the keyword fallback, checked in this order.
    _KENYAN_TOWNS = (
        "Nairobi", "Mombasa", "Kisumu", "Nakuru", "Eldoret",
        "Westlands", "Karen", "Runda", "Thika", "Naivasha",
    )
    # Last-resort heuristic: the word following "in", e.g. "a leak in Kilimani".
    _IN_LOCATION_PATTERN = re.compile(r"\bin\s+([A-Za-z]+)", re.IGNORECASE)

    def __init__(self):
        # Initialize all required attributes
        self.api_url = "https://api-inference.huggingface.co/models/Christy123/service-classifier"
        self.api_token = os.getenv("HF_API_TOKEN")  # Make sure this is in your .env
        self.ner_url = "https://api-inference.huggingface.co/models/dbmdz/bert-large-cased-finetuned-conll03-english"
        # Keyword -> service category for the offline fallback. Keywords match
        # at the start of a word ("plumb" covers "plumber"); keywords of two
        # characters or fewer ("ac") must be a whole word, optionally plural.
        self.service_mappings = {
            "ac": "hvac",
            "air conditioner": "hvac",
            "plumb": "plumbing",
            "pipe": "plumbing",
            "electr": "electrical",
            "wiring": "electrical",
        }

    def _query_api(self, url: str, text: str):
        """POST ``text`` to a Hugging Face Inference API endpoint.

        Returns:
            The decoded JSON body of an HTTP 200 response, or ``None`` when no
            API token is configured, the request fails or times out, the
            status is not 200, or the body is not valid JSON.
        """
        if not self.api_token:
            # Without a token the header would carry the literal "Bearer None";
            # skip the round trip and let the caller use its fallbacks.
            return None
        try:
            response = requests.post(
                url,
                headers={"Authorization": f"Bearer {self.api_token}"},
                json={"inputs": text},
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as exc:
            logging.getLogger(__name__).warning(
                "Hugging Face request to %s failed: %s", url, exc
            )
            return None
        if response.status_code != 200:
            return None
        try:
            return response.json()
        except ValueError:
            return None

    @staticmethod
    def _best_prediction(payload):
        """Return the highest-scoring ``{"label", "score"}`` dict, or ``None``.

        Accepts a flat list of predictions, the nested per-input list shape
        (``[[{...}, ...]]``), or a single dict. Anything without a numeric
        ``score`` and a ``label`` (e.g. an error object) is ignored.
        """
        if isinstance(payload, dict):
            candidates = [payload]
        elif isinstance(payload, list):
            candidates = []
            for item in payload:
                if isinstance(item, list):
                    candidates.extend(item)
                else:
                    candidates.append(item)
        else:
            return None
        scored = [
            c for c in candidates
            if isinstance(c, dict)
            and "label" in c
            and isinstance(c.get("score"), (int, float))
        ]
        return max(scored, key=lambda c: c["score"], default=None)

    def _match_service_keyword(self, text: str) -> Optional[str]:
        """Return the service of the first ``service_mappings`` keyword in ``text``."""
        for keyword, service in self.service_mappings.items():
            # Short abbreviations such as "ac" would otherwise match inside
            # ordinary words ("replace", "contact").
            suffix = r"s?\b" if len(keyword) <= 2 else ""
            if re.search(rf"\b{re.escape(keyword)}{suffix}", text, re.IGNORECASE):
                return service
        return None

    def extract_service(self, text: str) -> str:
        """Enhanced service classification with fallback.

        Uses the remote classifier when its best prediction scores above
        ``CONFIDENCE_THRESHOLD``; otherwise falls back to keyword matching
        against ``service_mappings``. Remote failures never prevent the
        keyword fallback from running.

        Returns:
            The predicted label or mapped category, or ``"unknown"`` when
            nothing matches (including for empty input).
        """
        if not text:
            return "unknown"
        best = self._best_prediction(self._query_api(self.api_url, text))
        if best is not None and best["score"] > self.CONFIDENCE_THRESHOLD:
            return best["label"]
        service = self._match_service_keyword(text)
        return service if service is not None else "unknown"

    def extract_location(self, text: str) -> Optional[str]:
        """Enhanced location detection with fallback methods.

        Tries, in order: the first ``LOC`` entity from the remote NER model; a
        known Kenyan town mentioned as a whole word (returned in its canonical
        casing); the word following "in". Returns ``None`` if all fail or
        ``text`` is empty.
        """
        if not text:
            return None
        entities = self._query_api(self.ner_url, text)
        if isinstance(entities, list):
            for entity in entities:
                if (
                    isinstance(entity, dict)
                    and entity.get("entity_group") == "LOC"
                    and entity.get("word")
                ):
                    return entity["word"]
        # Fallback 1: known town names
        for town in self._KENYAN_TOWNS:
            if re.search(rf"\b{re.escape(town)}\b", text, re.IGNORECASE):
                return town
        # Fallback 2: "in <location>" pattern.
        # NOTE(review): this also captures non-places ("help in fixing" -> "fixing").
        match = self._IN_LOCATION_PATTERN.search(text)
        return match.group(1) if match else None