|
|
import os |
|
|
import uuid |
|
|
import json |
|
|
import requests |
|
|
from typing import Dict, Any, Optional |
|
|
|
|
|
from fastapi import FastAPI, Body |
|
|
from pydantic import BaseModel, Field |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Configuration (each value overridable via environment variable) ---

# Hugging Face repo id of the base chat model.
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# LoRA adapter fine-tuned for the entrepreneurship chat domain.
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ethnmcl/tinyllama-entrepreneurchatbot-lora")

# External XGBoost scoring endpoint, called once all assessment answers are collected.
XGB_SCORE_URL = os.environ.get("XGB_SCORE_URL", "https://ethnmcl-EntrepreneurialReadinessScoreAPI.hf.space/score")

# System prompt prepended to every free-form /chat LLM generation.
SYSTEM_PROMPT = (
    "You are a concise entrepreneurial readiness assistant. "
    "Be clear, specific, and professional."
)
|
|
|
|
|
# Ordered assessment questions. Each `key` becomes a feature name sent to the
# scoring API; the user is asked the questions one at a time, in this order.
QUESTION_FLOW = [
    {"key": "age", "question": "What is your age? (number)"},
    {"key": "savings", "question": "How much do you currently have saved (USD)?"},
    {"key": "monthly_expense_ratio", "question": "What is your monthly expense ratio (expenses/income)? (e.g., 0.55)"},
    {"key": "sales_experience", "question": "Rate your sales experience from 0–10."},
    {"key": "dependents", "question": "How many dependents do you support? (number)"},
    {"key": "weekly_time_commitment", "question": "How many hours/week can you commit to your venture?"},
]

# Per-feature casts applied to the raw (string) answers before scoring.
# All current features are numeric; cast_features() falls back to the raw
# string if a cast fails.
TYPE_CASTS = {
    "age": float,
    "savings": float,
    "monthly_expense_ratio": float,
    "sales_experience": float,
    "dependents": float,
    "weekly_time_commitment": float,
}
|
|
|
|
|
# Canned example prompts surfaced to clients via GET /examples.
EXAMPLES = [
    "Can you explain what the entrepreneurial readiness check is in one or two sentences?",
    "I am 31, with $8,000 savings and an expense ratio of 0.72. What does that say about my readiness?",
    "I’m 27, working 15 hours a week on my business, with 2 dependents. How might that affect my entrepreneurial score?",
    "Inputs → age 29, savings 5000, expense ratio 0.62, sales experience 4, dependents 1, hours/week 12. Summarize likely strengths and risks in bullet points.",
    "Based on these inputs (savings 3000, expense ratio 0.85, sales exp 2), what 3 actions should I take to improve my readiness?",
    "I feel nervous about launching with only $2,000 saved. Can you give me encouragement and one practical step?",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In-memory session store (per-process; lost on restart, not shared across workers).
# Maps session id -> {"answers": dict, "idx": int, "active": bool, "scored": bool, "score": Any}.
SESSIONS: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
def new_session() -> str:
    """Create a fresh, inactive session in SESSIONS and return its id."""
    session_id = uuid.uuid4().hex
    SESSIONS[session_id] = {
        "answers": {},
        "idx": 0,
        "active": False,
        "scored": False,
        "score": None,
    }
    return session_id
|
|
|
|
|
def get_session(session_id: Optional[str]) -> str:
    """Return the given id if it names a known session; otherwise mint a new one."""
    if not session_id:
        return new_session()
    if session_id not in SESSIONS:
        return new_session()
    return session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the tokenizer and base model, then try to attach the LoRA adapter.

    Populates the module-level ``tokenizer`` and ``model`` globals used by
    llm_reply(). If the adapter (or peft itself) is unavailable, falls back
    to serving the bare base model.
    """
    global tokenizer, model
    # Fast (Rust-backed) tokenizer for the base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)

    # On GPU, load 4-bit quantized fp16 to fit small VRAM; on CPU, plain fp32.
    # NOTE(review): load_in_4bit requires bitsandbytes at runtime — confirm it
    # is installed in the GPU image, otherwise from_pretrained will raise.
    use_4bit = torch.cuda.is_available()
    kwargs = {"device_map": "auto"}
    if use_4bit:
        kwargs.update(dict(load_in_4bit=True, torch_dtype=torch.float16))
    else:
        kwargs.update(dict(torch_dtype=torch.float32))

    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **kwargs)
    try:
        # Imported lazily so the service can still start without peft installed.
        from peft import PeftModel
        model = PeftModel.from_pretrained(base, ADAPTER_ID)
    except Exception as e:
        # Best-effort: a missing/broken adapter must not kill startup —
        # degrade to the un-adapted base model instead.
        print(f"[WARN] Failed to load PEFT adapter {ADAPTER_ID}: {e}")
        model = base

    # Inference-only mode (disables dropout etc.).
    model.eval()
    print("[INFO] Model ready. CUDA:", torch.cuda.is_available())
|
|
|
|
|
def llm_reply(system: str, user: str, max_new_tokens=180, temperature=0.2) -> str:
    """Generate one assistant reply from the loaded chat model.

    Args:
        system: System prompt injected into the chat template.
        user: The user's message.
        max_new_tokens: Generation budget for the reply.
        temperature: Sampling temperature; values <= 0 select greedy decoding.

    Returns:
        The text after the final assistant marker, stripped of whitespace.
    """
    # TinyLlama chat-template markers (<|system|>/<|user|>/<|assistant|>).
    prompt = f"<|system|>\n{system}\n<|user|>\n{user}\n<|assistant|>\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Fix: only forward `temperature` when sampling. Passing temperature<=0,
    # or any temperature alongside do_sample=False, makes transformers warn
    # (and newer versions reject the GenerationConfig outright).
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,  # silence missing-pad warning
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        gen_kwargs["do_sample"] = False  # greedy decoding

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # The decoded text contains the prompt; keep only the new completion.
    return text.split("<|assistant|>")[-1].strip()
|
|
|
|
|
def next_question(state: Dict[str, Any]) -> Optional[str]:
    """Return the current question's prompt, or None when the flow is exhausted."""
    position = state["idx"]
    if position >= len(QUESTION_FLOW):
        return None
    return QUESTION_FLOW[position]["question"]
|
|
|
|
|
def cast_features(answers: Dict[str, str]) -> Dict[str, Any]:
    """Convert raw string answers via TYPE_CASTS, keeping the raw value on failure."""
    converted: Dict[str, Any] = {}
    for key, raw in answers.items():
        cast = TYPE_CASTS.get(key, str)
        try:
            value = cast(raw)
        except Exception:
            # Unparseable input (e.g. "thirty") is passed through unchanged.
            value = raw
        converted[key] = value
    return converted
|
|
|
|
|
def score_via_api(features: Dict[str, Any]) -> Dict[str, Any]:
    """POST the collected features to the external XGBoost scoring service.

    Returns the service's JSON payload on success, or {"error": ...} on any
    failure (network, timeout, non-2xx status, malformed JSON).
    """
    body = {"features": features}
    try:
        response = requests.post(XGB_SCORE_URL, json=body, timeout=20)
        response.raise_for_status()
        return response.json()
    except Exception as exc:
        return {"error": f"Scoring API error: {exc}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="Entrepreneurial Readiness Chat API", version="1.0.0")

# Wide-open CORS so any frontend origin can call the API.
# NOTE(review): per the CORS spec, browsers reject Access-Control-Allow-Origin
# "*" combined with credentials — confirm whether allow_credentials=True is
# actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
|
|
|
|
|
class ChatIn(BaseModel):
    """Request body for POST /chat."""
    message: str = Field(..., description="User message")
    # Omitting session_id starts a new session; see get_session().
    session_id: Optional[str] = Field(None, description="Session ID (optional)")
|
|
|
|
|
class ChatOut(BaseModel):
    """Response body for POST /chat."""
    session_id: str
    reply: str
    # Snapshot of the session's assessment state ("active", "idx", "total", ...).
    assessment: Dict[str, Any]
|
|
|
|
|
class StartOut(BaseModel):
    """Response body for POST /assessment/start."""
    session_id: str
    question: str
    # 1-based index of the question being asked, out of `total`.
    idx: int
    total: int
|
|
|
|
|
class AnswerIn(BaseModel):
    """Request body for POST /assessment/answer."""
    session_id: str
    # Raw answer text; cast to the feature's type later by cast_features().
    answer: str
|
|
|
|
|
class AnswerOut(BaseModel):
    """Response body for POST /assessment/answer."""
    session_id: str
    # True once the final question has been answered and scoring ran.
    done: bool
    # Next question to ask; None when done or on error.
    question: Optional[str] = None
    idx: int
    total: int
    # Scoring API result (or its {"error": ...} payload); only set when done.
    score: Optional[Any] = None
    # Set for invalid-session / inactive-assessment requests.
    error: Optional[str] = None
|
|
|
|
|
@app.on_event("startup")
def _startup():
    """Load tokenizer/model once at server start (blocks until ready)."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers — fine for the currently pinned version.
    load_model()
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe."""
    status = {"ok": True}
    return status
|
|
|
|
|
@app.get("/examples")
def examples():
    """Expose the canned example prompts for frontend suggestion chips."""
    return dict(examples=EXAMPLES)
|
|
|
|
|
@app.post("/assessment/start", response_model=StartOut)
def assessment_start(payload: Dict[str, Any] = Body(default={})):
    """Reset (or create) a session and return the first assessment question."""
    sid = get_session(payload.get("session_id"))
    state = SESSIONS[sid]
    # Wipe any previous run and mark the flow active.
    state["answers"] = {}
    state["idx"] = 0
    state["active"] = True
    state["scored"] = False
    state["score"] = None
    return {
        "session_id": sid,
        "question": next_question(state),
        "idx": state["idx"] + 1,  # 1-based for display
        "total": len(QUESTION_FLOW),
    }
|
|
|
|
|
@app.post("/assessment/answer", response_model=AnswerOut)
def assessment_answer(inp: AnswerIn):
    """Record one answer for an active assessment session.

    Advances the question pointer; after the last question it casts the
    collected answers and scores them via the external XGBoost API.
    """
    # Unknown session: hand back a brand-new session id plus an error.
    if inp.session_id not in SESSIONS:
        return AnswerOut(session_id=new_session(), done=False, idx=0, total=len(QUESTION_FLOW), error="Invalid session_id")

    st = SESSIONS[inp.session_id]
    # Answers are only accepted between /assessment/start and completion.
    if not st.get("active"):
        return AnswerOut(session_id=inp.session_id, done=False, idx=st["idx"], total=len(QUESTION_FLOW), error="Assessment not active.")

    # Store the answer under the current question's feature key, then advance.
    cur_key = QUESTION_FLOW[st["idx"]]["key"]
    st["answers"][cur_key] = inp.answer.strip()
    st["idx"] += 1

    q = next_question(st)
    if q is None:
        # Flow finished: deactivate, cast answers to numbers, and score.
        st["active"] = False
        features = cast_features(st["answers"])
        res = score_via_api(features)
        # A scoring failure is still returned in `score`, but scored stays False.
        st["scored"] = "error" not in res
        st["score"] = res
        return AnswerOut(session_id=inp.session_id, done=True, idx=len(QUESTION_FLOW), total=len(QUESTION_FLOW), score=res)
    else:
        # More questions remain; idx is 1-based in the response.
        return AnswerOut(session_id=inp.session_id, done=False, question=q, idx=st["idx"] + 1, total=len(QUESTION_FLOW))
|
|
|
|
|
@app.post("/chat", response_model=ChatOut)
def chat(inp: ChatIn):
    """Free-form chat endpoint; certain phrases kick off the assessment flow."""
    sid = get_session(inp.session_id)
    state = SESSIONS[sid]

    # Detect an explicit request to run the readiness assessment.
    lowered = inp.message.lower()
    triggers = ["take the entrepreneurial readiness assessment", "take assessment", "start assessment", "readiness assessment"]
    wants_assessment = any(phrase in lowered for phrase in triggers)

    if wants_assessment:
        # Restart the question flow in this session.
        state["answers"] = {}
        state["idx"] = 0
        state["active"] = True
        state["scored"] = False
        state["score"] = None
        question = next_question(state)
        intro = (
            "Great—let’s do a short 6-question entrepreneurial readiness check.\n\n"
            f"**Q1/6**: {question}"
        )
        snapshot = {
            "active": True,
            "idx": state["idx"] + 1,
            "total": len(QUESTION_FLOW),
            "question": question,
        }
        return ChatOut(session_id=sid, reply=intro, assessment=snapshot)

    # Otherwise answer with the fine-tuned model.
    answer = llm_reply(SYSTEM_PROMPT, inp.message)
    return ChatOut(
        session_id=sid,
        reply=answer,
        assessment={"active": state.get("active", False), "idx": state.get("idx", 0), "total": len(QUESTION_FLOW)},
    )
|
|
|