File size: 3,624 Bytes
a7f10f6
8825a51
 
 
a7f10f6
 
8825a51
 
 
a7f10f6
8825a51
e7065b3
8825a51
 
a7f10f6
8825a51
 
 
 
 
a7f10f6
 
 
8825a51
a7f10f6
 
 
 
 
 
 
 
 
 
 
e7065b3
 
a7f10f6
 
 
d318a0b
a7f10f6
e7065b3
a7f10f6
8825a51
 
 
a7f10f6
 
 
 
 
 
 
 
8825a51
 
e7065b3
8825a51
 
 
 
e7065b3
8825a51
 
 
a7f10f6
 
 
 
e7065b3
a7f10f6
 
 
 
 
 
e7065b3
e41db88
 
d318a0b
e7065b3
d318a0b
e7065b3
 
 
 
97d5afe
 
e7065b3
e41db88
 
a7f10f6
e7065b3
d318a0b
 
e7065b3
d318a0b
 
 
 
 
 
 
a7f10f6
e7065b3
 
 
a7f10f6
e7065b3
a7f10f6
 
 
e7065b3
a7f10f6
 
8825a51
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import os

# Initialize FastAPI application; title shows up in the auto-generated docs.
app = FastAPI(title="Davidic Sermon Intelligence API")

# Add CORS Middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# the fully-open CORS configuration — acceptable for a demo/Space, but
# confirm it is intended before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load Models at import time so the first request does not pay the load cost.
# All three models are downloaded from the Hugging Face Hub on first run.

# Bi-encoder used by /embed to turn text into a dense vector.
print("Loading Embedding model...")
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Cross-encoder used by /rerank to score (query, document) pairs.
print("Loading Reranker model...")
reranker_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Small chat LLM used by /insight; float32 + low_cpu_mem_usage suggests a
# CPU-only deployment target.
print("Loading Tiny LLM (TinyLlama-1.1B)...")
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=torch.float32, 
    low_cpu_mem_usage=True
)

# Pipeline WITHOUT generation config to avoid warnings; all sampling
# parameters are passed explicitly at call time in /insight instead.
llm_pipeline = pipeline(
    "text-generation", 
    model=llm_model, 
    tokenizer=tokenizer
)
print("All models loaded Ready.")

class EmbedRequest(BaseModel):
    """Request body for POST /embed: a single text to vectorize."""
    text: str

class RerankRequest(BaseModel):
    """Request body for POST /rerank: a query plus candidate documents to score."""
    query: str
    documents: list[str]

class InsightRequest(BaseModel):
    """Request body for POST /insight: a question and the transcript context to ground the answer in."""
    query: str
    context: str

@app.get("/")
def health_check():
    """Liveness probe: confirm the service (and its models) finished loading."""
    status_payload = {"status": "running"}
    return status_payload

@app.post("/embed")
def embed(request: EmbedRequest):
    """Encode the request text with the bi-encoder and return the vector as a JSON list of floats."""
    try:
        vector = embedding_model.encode(request.text)
        return vector.tolist()
    except Exception as exc:
        # Surface any model failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/rerank")
def rerank(request: RerankRequest):
    """Score every candidate document against the query with the cross-encoder.

    Returns a list of relevance scores, one per document, in input order.
    """
    try:
        scoring_pairs = [[request.query, candidate] for candidate in request.documents]
        scores = reranker_model.predict(scoring_pairs)
        return scores.tolist()
    except Exception as exc:
        # Surface any model failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/insight")
def generate_insight(request: InsightRequest):
    """Generate a long-form explanation of the query, grounded in the supplied transcript context.

    Builds a TinyLlama chat-formatted prompt, samples a completion, and
    returns only the assistant portion as {"insight": ...}.
    """
    try:
        print(f"Generating insight for: {request.query}")

        # Chat template: system rules, then user context/question, then the
        # assistant tag the model completes after.
        system_section = (
            "<|system|>\n"
            "You are a helpful spiritual assistant for Davidic Generation Church. "
            "Explain the spiritual context of the videos below based on their transcripts.\n"
            "RULES:\n"
            "1. Refer to videos like this: 'In [Video 1], Pastor explains...'.\n"
            "2. Summarize WHY this moment is relevant to the question.\n"
            "3. Do NOT just repeat the transcript. Explain the meaning.\n"
            "4. Be thorough and long-form.\n"
        )
        user_section = (
            f"<|user|>\n"
            f"CONTEXT:\n{request.context}\n\n"
            f"QUESTION: {request.query}\n"
        )
        prompt = system_section + user_section + "<|assistant|>\n"

        # Every sampling parameter is pinned here so the pipeline's defaults
        # (and their deprecation warnings) never apply.
        generation = llm_pipeline(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # The pipeline echoes the prompt; keep only the text after the
        # assistant tag (fallback: slice off the prompt prefix).
        full_text = generation[0]['generated_text']
        marker = "<|assistant|>"
        if marker in full_text:
            insight = full_text.split(marker)[-1].strip()
        else:
            insight = full_text[len(prompt):].strip()

        return {"insight": insight}
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Script entry point: run the API with uvicorn when executed directly.
if __name__ == "__main__":
    import uvicorn
    # Bind all interfaces; port 7860 matches the Hugging Face Spaces
    # convention — confirm against the actual deployment target.
    uvicorn.run(app, host="0.0.0.0", port=7860)