# policy-analysis/utils/generation_streaming.py
# (Hugging Face Space artifact — author: kaburia, hardware: ZeroGPU, rev 0ed976e)
"""Streaming response generation for the policy-analysis Space.

Retrieves and reranks documents from the vectorstore, optionally runs
sentiment and coherence analysis, then streams an LLM completion back
to the UI chunk by chunk.
"""
import json
import os
import time

import requests
import spaces

# coherence assessment reports
from utils.coherence_bbscore import coherence_report
# encode the text (not used in the active code path — kept for compatibility)
from utils.encoding_input import encode_text
# get the vectorstore
from utils.loading_embeddings import get_vectorstore
# build messages for model generation
from utils.model_generation import build_messages
# retrieve and rerank the documents
from utils.retrieve_n_rerank import retrieve_and_rerank
# sentiment analysis on reranked documents
from utils.sentiment_analysis import get_sentiment

# SECURITY: never commit API keys to source. The previously hard-coded key
# was exposed in the repository history and must be rotated; the key is now
# read from the environment (set DO_API_KEY in the Space secrets).
API_KEY = os.getenv("DO_API_KEY", "")
MODEL = "llama3.3-70b-instruct"
@spaces.GPU(duration=120)
def generate_response_stream(query: str, enable_sentiment: bool, enable_coherence: bool):
    """Stream an LLM answer for *query*, grounded in reranked documents.

    Yields progressively longer strings (the accumulated response so far),
    suitable for a streaming chat UI. Yields a single message instead when
    retrieval finds nothing or the completion API call fails.

    Args:
        query: The user's question.
        enable_sentiment: If True, run sentiment analysis on the top docs
            and feed the rollup into the prompt.
        enable_coherence: If True, compute a coherence report on the top
            docs and feed it into the prompt.
    """
    # Initialize vectorstore lazily so module import stays cheap.
    vectorstore = get_vectorstore()

    reranked_results = retrieve_and_rerank(
        query_text=query,
        vectorstore=vectorstore,
        k=50,  # number of initial documents to retrieve
        rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        top_m=20,  # number of documents to return after reranking
        min_score=0.5,  # minimum score for reranked documents
        only_docs=False,  # return (document, score) pairs
    )
    top_docs = [doc for doc, _score in reranked_results]
    if not top_docs:
        yield "No relevant documents found."
        return

    # Optional analyses feed extra context into the prompt.
    sentiment_rollup = get_sentiment(top_docs) if enable_sentiment else {}
    coherence_report_ = (
        coherence_report(reranked_results=top_docs, input_text=query)
        if enable_coherence
        else ""
    )

    messages = build_messages(
        query=query,
        top_docs=top_docs,
        task_mode="verbatim_sentiment",
        sentiment_rollup=sentiment_rollup,
        coherence_report=coherence_report_,
    )

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    data = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.2,
        "stream": True,
        "max_tokens": 2000,
    }

    collected = ""  # accumulated content shown to the user
    try:
        with requests.post(
            "https://inference.do-ai.run/v1/chat/completions",
            headers=headers,
            json=data,
            stream=True,
            # (connect, read) timeouts — without these a stalled connection
            # would hang the generator (and the UI) indefinitely.
            timeout=(10, 120),
        ) as r:
            if r.status_code != 200:
                yield f"[ERROR] API returned status {r.status_code}: {r.text}"
                return
            for line in r.iter_lines(decode_unicode=True):
                # Skip SSE keep-alive blanks and the terminal sentinel.
                if not line or line.strip() == "data: [DONE]":
                    continue
                if line.startswith("data: "):
                    line = line[len("data: "):]
                try:
                    chunk = json.loads(line)
                    delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                except (json.JSONDecodeError, AttributeError, IndexError, TypeError) as e:
                    # Malformed or partial SSE payload — log and keep streaming.
                    print("Streaming decode error:", e)
                    continue
                if delta:
                    collected += delta
                    yield collected  # yield progressively growing text
                    time.sleep(0.01)  # slight throttle to improve smoothness
    except requests.RequestException as e:
        # Network-level failure (timeout, DNS, connection reset) — surface it
        # to the UI the same way HTTP-status errors are surfaced.
        yield f"[ERROR] Request failed: {e}"