|
|
import os |
|
|
import json |
|
|
import re |
|
|
from typing import List, Dict, Any, Set |
|
|
import faiss |
|
|
import numpy as np |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from utils import base_utils as bu |
|
|
from utils import retrieval_utils as ru |
|
|
|
|
|
class QueryRequest(BaseModel): |
|
|
question: str |
|
|
mode: str |
|
|
top_k: int | None = None |
|
|
temperature: float | None = None |
|
|
|
|
|
class QueryResponse(BaseModel): |
|
|
answer: str |
|
|
retrieved: List[Dict[str, Any]] |
|
|
|
|
|
|
|
|
def load_index_and_metadata(config_path: str = "configs/config.json"): |
|
|
config = bu.load_config(config_path) |
|
|
|
|
|
index_path = config["index"].get("index_file", "data/index/faiss.index") |
|
|
metadata_path = config["index"].get("metadata_file", "data/index/metadata.jsonl") |
|
|
|
|
|
if not os.path.exists(index_path) or not os.path.exists(metadata_path): |
|
|
raise FileNotFoundError( |
|
|
"Index or metadata not found. Run scripts/build_index.py first." |
|
|
) |
|
|
|
|
|
index = faiss.read_index(index_path) |
|
|
|
|
|
metadata: List[Dict[str, Any]] = [] |
|
|
with open(metadata_path, "r", encoding="utf-8") as f: |
|
|
for line in f: |
|
|
if line.strip(): |
|
|
metadata.append(json.loads(line)) |
|
|
|
|
|
return index, metadata, config |
|
|
|
|
|
|
|
|
def load_embedding_model(config: dict): |
|
|
model_name = config["embeddings"]["model_name"] |
|
|
return ru.load_model(model_name) |
|
|
|
|
|
|
|
|
def embed_query(model, text: str) -> np.ndarray: |
|
|
vec = model.encode(text) |
|
|
vec = np.asarray(vec, dtype="float32")[None, :] |
|
|
faiss.normalize_L2(vec) |
|
|
return vec |
|
|
|
|
|
|
|
|
def _normalizar_texto(texto: str) -> str: |
|
|
"""Limpia saltos de línea y espacios raros en los fragmentos.""" |
|
|
|
|
|
if not texto: |
|
|
return "" |
|
|
|
|
|
|
|
|
texto = texto.replace("\r", " ").replace("\n", " ") |
|
|
texto = re.sub(r"\s+", " ", texto) |
|
|
return texto.strip() |
|
|
|
|
|
|
|
|
def call_llm(config: dict, question: str, context_fragments: List[Dict[str, Any]], temperature: float | None = None, mode="chatbot") -> str: |
|
|
|
|
|
|
|
|
if not context_fragments: |
|
|
return "Não há informação suficiente na base. (nenhum trecho recuperado)" |
|
|
|
|
|
provider = config.get("llm", {}).get("provider", "openai") |
|
|
model_name = config.get("llm", {}).get("model", "gpt-4o-mini") |
|
|
|
|
|
|
|
|
citation_ids: Dict[str, int] = {} |
|
|
next_id = 1 |
|
|
for m in context_fragments: |
|
|
doc_id = m.get("document_id") |
|
|
if not doc_id: |
|
|
continue |
|
|
if doc_id not in citation_ids: |
|
|
citation_ids[doc_id] = next_id |
|
|
next_id += 1 |
|
|
|
|
|
|
|
|
context_lines: List[str] = [] |
|
|
for m in context_fragments: |
|
|
doc_id = m.get("document_id") |
|
|
frag_id = m.get("fragment_id") |
|
|
if not doc_id: |
|
|
continue |
|
|
cit_num = citation_ids.get(doc_id) |
|
|
tag = f"[{cit_num}] (trecho #{frag_id})" if cit_num is not None else f"[{doc_id}#{frag_id}]" |
|
|
texto = _normalizar_texto(m.get("content", "")) |
|
|
if len(texto) > 500: |
|
|
texto = texto[:500] + "..." |
|
|
context_lines.append(f"{tag} {texto}") |
|
|
context_text = "\n".join(context_lines) |
|
|
|
|
|
|
|
|
referencias_lines: List[str] = [] |
|
|
for doc_id, cit_num in sorted(citation_ids.items(), key=lambda x: x[1]): |
|
|
titulo = None |
|
|
for m in context_fragments: |
|
|
if m.get("document_id") == doc_id: |
|
|
titulo = m.get("document_title") or doc_id |
|
|
break |
|
|
referencias_lines.append(f"[{cit_num}] {titulo}") |
|
|
referencias_text = "\n".join(referencias_lines) |
|
|
|
|
|
|
|
|
system_prompt = PROMPTS.get(mode) |
|
|
if not system_prompt: |
|
|
system_prompt = "Você é um assistente especializado em química e NORM. Use SOMENTE os trechos de contexto fornecidos e nunca invente dados. Se a pergunta mencionar anos posteriores a 2025 (por exemplo 2026 em diante), responda explicitamente que não há informações disponíveis para esses anos na base e não tente extrapolar. Quando citar fontes, use o formato [n], onde n corresponde ao número da referência na tabela de referências fornecida. " |
|
|
|
|
|
|
|
|
if provider == "openai": |
|
|
try: |
|
|
from openai import OpenAI |
|
|
except ImportError: |
|
|
return ( |
|
|
"Não há informação suficiente na base. " |
|
|
"(biblioteca openai não instalada no ambiente)" |
|
|
) |
|
|
|
|
|
api_key = os.getenv("OPENAI_API_KEY") |
|
|
if not api_key: |
|
|
return ( |
|
|
"Não há informação suficiente na base. " |
|
|
"(variável de ambiente OPENAI_API_KEY não configurada)" |
|
|
) |
|
|
|
|
|
client = OpenAI(api_key=api_key) |
|
|
messages = [ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{ |
|
|
"role": "user", |
|
|
"content": ( |
|
|
"Contexto:\n" + context_text + "\n\n" + |
|
|
"Referências (por número):\n" + referencias_text + "\n\n" + |
|
|
f"Pergunta: {question}" |
|
|
), |
|
|
}, |
|
|
] |
|
|
|
|
|
temperature_value = 0.5 if temperature is None else float(temperature) |
|
|
|
|
|
resp = client.chat.completions.create( |
|
|
model=model_name, |
|
|
messages=messages, |
|
|
temperature=temperature_value, |
|
|
) |
|
|
return resp.choices[0].message.content.strip() |
|
|
|
|
|
|
|
|
resumen = [] |
|
|
for line in context_lines: |
|
|
if len(line) > 300: |
|
|
line = line[:300] + "..." |
|
|
resumen.append(line) |
|
|
joined = "\n".join(resumen) |
|
|
return ( |
|
|
"Provvedor de LLM não suportado; mostrando trechos brutos:\n" + joined |
|
|
) |
|
|
|
|
|
index, METADATA, CONFIG = load_index_and_metadata() |
|
|
EMBED_MODEL = load_embedding_model(CONFIG) |
|
|
PROMPTS = bu.load_prompts() |
|
|
|
|
|
app = FastAPI(title="chatbot-Norm") |
|
|
|
|
|
@app.get("/list_documents") |
|
|
async def list_documents() -> Dict[str, Any]: |
|
|
"""Retorna a lista de documentos presentes na metadata, com IDs e títulos.""" |
|
|
|
|
|
|
|
|
docs: Dict[str, str] = {} |
|
|
for m in METADATA: |
|
|
doc_id = m.get("document_id") |
|
|
if not doc_id: |
|
|
continue |
|
|
|
|
|
titulo = m.get("document_title") or doc_id |
|
|
if doc_id not in docs: |
|
|
docs[doc_id] = titulo |
|
|
|
|
|
|
|
|
documentos_ordenados = [ |
|
|
{"id": doc_id, "title": docs[doc_id]} for doc_id in sorted(docs, key=lambda d: docs[d].lower()) |
|
|
] |
|
|
|
|
|
return {"documents": documentos_ordenados} |
|
|
|
|
|
@app.post("/query", response_model=QueryResponse) |
|
|
async def query_endpoint(payload: QueryRequest) -> QueryResponse: |
|
|
|
|
|
top_k = payload.top_k or CONFIG.get("retrieve", {}).get("top_k", 4) |
|
|
temperature = payload.temperature |
|
|
mode = payload.mode |
|
|
|
|
|
|
|
|
q_vec = embed_query(EMBED_MODEL, payload.question) |
|
|
scores, indices = index.search(q_vec, top_k) |
|
|
idxs = indices[0].tolist() |
|
|
|
|
|
|
|
|
retrieved: List[Dict[str, Any]] = [] |
|
|
meta_by_idx = {m["idx"]: m for m in METADATA} |
|
|
for i in idxs: |
|
|
m = meta_by_idx.get(i) |
|
|
if m is not None: |
|
|
retrieved.append(dict(m)) |
|
|
|
|
|
|
|
|
|
|
|
citation_ids: Dict[str, int] = {} |
|
|
next_id = 1 |
|
|
for m in retrieved: |
|
|
doc_id = m.get("document_id") |
|
|
if not doc_id: |
|
|
continue |
|
|
if doc_id not in citation_ids: |
|
|
citation_ids[doc_id] = next_id |
|
|
next_id += 1 |
|
|
m["citation_id"] = citation_ids[doc_id] |
|
|
|
|
|
|
|
|
answer = call_llm(CONFIG, payload.question, retrieved, temperature, mode) |
|
|
|
|
|
|
|
|
if answer.strip().startswith("Não há informação suficiente na base") or \ |
|
|
answer.strip().startswith("Desculpe, mas não tenho informações"): |
|
|
return QueryResponse(answer=answer, retrieved=[]) |
|
|
|
|
|
|
|
|
saludo_regex = r"^(ol[áa]|hola|hello|hi)[!,.]?" |
|
|
saludo_match = re.match(saludo_regex, answer.strip().lower()) |
|
|
generic_prefixes = [ |
|
|
"eu sou um assistente", |
|
|
"sou um assistente", |
|
|
"sou um assistente virtual", |
|
|
"eu sou um assistente virtual" |
|
|
] |
|
|
if saludo_match or any(answer.strip().lower().startswith(p) for p in generic_prefixes): |
|
|
return QueryResponse(answer=answer, retrieved=[]) |
|
|
|
|
|
import re as _re |
|
|
raw_citations = [int(m.group(1)) for m in _re.finditer(r"\[(\d+)\]", answer)] |
|
|
|
|
|
if raw_citations: |
|
|
|
|
|
unique_old_numbers = list(dict.fromkeys(raw_citations)) |
|
|
|
|
|
renumber_map = {old: new for new, old in enumerate(unique_old_numbers, start=1)} |
|
|
|
|
|
for old, new in renumber_map.items(): |
|
|
answer = answer.replace(f"[{old}]", f"[TEMP_{new}]") |
|
|
|
|
|
answer = answer.replace("TEMP_", "") |
|
|
|
|
|
answer = _re.sub(r'\[(\d+)\](\s*\[\1\])+', r'[\1]', answer) |
|
|
|
|
|
for m in retrieved: |
|
|
old_id = m.get("citation_id") |
|
|
if old_id in renumber_map: |
|
|
m["citation_id"] = renumber_map[old_id] |
|
|
else: |
|
|
m["citation_id"] = None |
|
|
|
|
|
retrieved = [m for m in retrieved if m.get("citation_id") is not None] |
|
|
|
|
|
else: |
|
|
retrieved = [] |
|
|
|
|
|
|
|
|
return QueryResponse(answer=answer, retrieved=retrieved) |
|
|
|
|
|
|