StrandDemo / app.py
Madras1's picture
Upload app.py
cf48579 verified
"""
Strand Data - Demo Backend
Deploy em HuggingFace Spaces
Modelo: Madras1/sbert_cosine_filter_v3
Sistema de Âncora: Centróide de exemplos de alta qualidade
Endpoints:
- POST /classify-quality: Classifica qualidade com sBERT + âncora
- POST /similarity: Retorna score de similaridade com âncora
- POST /qa: Q&A sobre texto usando LLM
- POST /caption: Gera descrição de imagem
"""
import os
import base64
import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
import torch
import numpy as np
from typing import Optional
app = FastAPI(title="Strand Data Demo API")
# CORS para permitir requests do frontend
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ================================
# Configuração
# ================================
# API Keys (usar secrets do HuggingFace)
CHUTES_API_KEY = os.getenv("CHUTES_API_KEY", "")
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
# Modelo sBERT - SEU MODELO FINE-TUNED
SBERT_MODEL_NAME = "Madras1/sbert_cosine_filter_v3"
# Threshold de qualidade (baseado no seu pipeline)
QUALITY_THRESHOLD = 0.65
print(f"🧠 Carregando modelo sBERT: {SBERT_MODEL_NAME}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"📱 Device: {device}")
sbert_model = SentenceTransformer(SBERT_MODEL_NAME)
sbert_model.to(device)
sbert_model.eval()
# ================================
# Sistema de Âncora de Qualidade
# ================================
# Possíveis caminhos do arquivo de âncora (HuggingFace Spaces pode variar)
POSSIBLE_ANCHOR_PATHS = [
"anchor_gold_vector.pt", # Mesmo diretório
"/app/anchor_gold_vector.pt", # Docker padrão
"/home/user/app/anchor_gold_vector.pt", # HF Spaces path
"../anchor_gold_vector.pt", # Um nível acima
]
ANCHOR_FILE_PATH = None
for path in POSSIBLE_ANCHOR_PATHS:
if os.path.exists(path):
ANCHOR_FILE_PATH = path
break
print(f"⚓ Procurando vetor âncora...")
if ANCHOR_FILE_PATH and os.path.exists(ANCHOR_FILE_PATH):
# Carregar o centróide pré-calculado do seu dataset de ouro
ANCHOR_EMBEDDING = torch.load(ANCHOR_FILE_PATH, map_location=device)
print(f"✅ Vetor âncora carregado de: {ANCHOR_FILE_PATH}")
print(f" Shape: {ANCHOR_EMBEDDING.shape}")
else:
# Fallback: calcular de exemplos hardcoded se arquivo não existir
print("⚠️ Arquivo de âncora não encontrado em nenhum caminho.")
print(f" Caminhos testados: {POSSIBLE_ANCHOR_PATHS}")
print(" Usando exemplos de fallback...")
FALLBACK_EXAMPLES = [
"Este artigo apresenta uma análise detalhada dos métodos de aprendizado de máquina aplicados à visão computacional, com resultados quantitativos robustos.",
"O estudo demonstra correlação significativa entre as variáveis analisadas, utilizando metodologia rigorosa e amostra representativa.",
"A implementação do algoritmo proposto apresenta complexidade O(n log n), com benchmarks comparativos contra soluções estado-da-arte.",
]
with torch.no_grad():
fallback_embeddings = sbert_model.encode(FALLBACK_EXAMPLES, convert_to_tensor=True)
ANCHOR_EMBEDDING = torch.mean(fallback_embeddings, dim=0)
print(" Âncora de fallback calculada.")
print(f" Threshold de qualidade: {QUALITY_THRESHOLD}")
# ================================
# Modelos de Request/Response
# ================================
class QualityRequest(BaseModel):
text: str
class QualityResponse(BaseModel):
quality: str # "high", "medium", "low"
similarity_score: float # Similaridade com âncora (0-1)
score_percent: float # Score em porcentagem (0-100)
threshold: float # Threshold usado
verdict: str # Descrição legível
class SimilarityRequest(BaseModel):
text: str
class SimilarityResponse(BaseModel):
similarity: float
is_high_quality: bool
class QARequest(BaseModel):
context: str
question: str
class QAResponse(BaseModel):
answer: str
class CaptionRequest(BaseModel):
image_base64: str
class CaptionResponse(BaseModel):
caption: str
# ================================
# Funções de Classificação
# ================================
def compute_quality_score(text: str) -> tuple[float, str, str]:
"""
Calcula score de qualidade usando similaridade de cosseno com âncora.
Retorna: (similarity_score, quality_label, verdict)
"""
with torch.no_grad():
# Encode com normalização para garantir cálculo correto de cosseno
text_embedding = sbert_model.encode(text, convert_to_tensor=True, normalize_embeddings=True)
# Normalizar o anchor também (se não estiver normalizado)
anchor_normalized = ANCHOR_EMBEDDING / torch.norm(ANCHOR_EMBEDDING)
# Debug
text_norm = torch.norm(text_embedding).item()
anchor_norm = torch.norm(anchor_normalized).item()
print(f"📊 DEBUG - Text embedding norm (deve ser ~1.0): {text_norm:.4f}")
print(f"📊 DEBUG - Anchor norm (deve ser ~1.0): {anchor_norm:.4f}")
print(f"📊 DEBUG - Text[:50]: {text[:50]}...")
# Similaridade de cosseno (com vetores normalizados = dot product)
similarity = util.cos_sim(text_embedding, anchor_normalized).item()
print(f"📊 DEBUG - Similaridade calculada: {similarity:.4f}")
# Classificação baseada no threshold
if similarity >= QUALITY_THRESHOLD:
quality = "high"
verdict = "✨ Texto de ALTA qualidade! Estrutura e conteúdo técnico excelentes."
elif similarity >= 0.45:
quality = "medium"
verdict = "📝 Qualidade MÉDIA. Tem potencial, mas pode ser aprimorado."
else:
quality = "low"
verdict = "⚠️ Qualidade BAIXA. Requer revisão significativa."
return similarity, quality, verdict
# ================================
# LLM Helpers
# ================================
async def call_llm(prompt: str, system: str = "", max_tokens: int = 500) -> str:
"""Chama LLM via Chutes ou OpenRouter."""
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
# Tentar Chutes primeiro
if CHUTES_API_KEY:
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
"https://llm.chutes.ai/v1/chat/completions",
headers={
"Authorization": f"Bearer {CHUTES_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": "deepseek-ai/DeepSeek-V3-0324",
"messages": messages,
"max_tokens": max_tokens,
"temperature": 0.7
}
)
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
except Exception as e:
print(f"Erro Chutes: {e}")
# Fallback para OpenRouter
if OPENROUTER_API_KEY:
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
"https://openrouter.ai/api/v1/chat/completions",
headers={
"Authorization": f"Bearer {OPENROUTER_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": "nex-agi/deepseek-v3.1-nex-n1:free",
"messages": messages,
"max_tokens": max_tokens
}
)
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
except Exception as e:
print(f"Erro OpenRouter: {e}")
raise HTTPException(status_code=503, detail="Nenhuma API de LLM disponível")
async def call_vision_llm(image_base64: str, prompt: str) -> str:
"""Chama LLM multimodal para image captioning."""
# Modelos de visão na Chutes (em ordem de preferência)
vision_models = [
"Qwen/Qwen2.5-VL-72B-Instruct-TEE", # TEE
"Qwen/Qwen3-VL-235B-A22B-Instruct", # Qwen3
]
if CHUTES_API_KEY:
for model_name in vision_models:
try:
print(f"🖼️ Tentando modelo de visão: {model_name}")
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
"https://llm.chutes.ai/v1/chat/completions",
headers={
"Authorization": f"Bearer {CHUTES_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": model_name,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
]
}
],
"max_tokens": 300
}
)
print(f" Status: {response.status_code}")
if response.status_code == 200:
result = response.json()["choices"][0]["message"]["content"]
print(f" ✅ Sucesso com {model_name}")
return result
else:
print(f" ❌ Erro: {response.text[:200]}")
except Exception as e:
print(f" ❌ Exceção: {e}")
raise HTTPException(status_code=503, detail="API de visão não disponível. Verifique os logs do container.")
# ================================
# Endpoints
# ================================
@app.get("/")
async def root():
return {
"message": "Strand Data Demo API",
"status": "online",
"model": SBERT_MODEL_NAME,
"threshold": QUALITY_THRESHOLD
}
@app.get("/health")
async def health():
return {
"status": "healthy",
"model_loaded": sbert_model is not None,
"device": device,
"anchor_calibrated": ANCHOR_EMBEDDING is not None
}
@app.post("/classify-quality", response_model=QualityResponse)
async def classify_quality(request: QualityRequest):
"""
Classifica a qualidade de um texto usando sBERT + sistema de âncora.
Usa similaridade de cosseno com centróide de exemplos de alta qualidade.
"""
if not request.text.strip():
raise HTTPException(status_code=400, detail="Texto não pode estar vazio")
similarity, quality, verdict = compute_quality_score(request.text)
return QualityResponse(
quality=quality,
similarity_score=round(similarity, 4),
score_percent=round(similarity * 100, 2),
threshold=QUALITY_THRESHOLD,
verdict=verdict
)
@app.post("/similarity", response_model=SimilarityResponse)
async def compute_similarity(request: SimilarityRequest):
"""
Endpoint simples: retorna apenas a similaridade com a âncora.
Útil para filtragem em batch.
"""
if not request.text.strip():
raise HTTPException(status_code=400, detail="Texto não pode estar vazio")
with torch.no_grad():
text_embedding = sbert_model.encode(request.text, convert_to_tensor=True, normalize_embeddings=True)
anchor_normalized = ANCHOR_EMBEDDING / torch.norm(ANCHOR_EMBEDDING)
similarity = util.cos_sim(text_embedding, anchor_normalized).item()
return SimilarityResponse(
similarity=round(similarity, 4),
is_high_quality=similarity >= QUALITY_THRESHOLD
)
@app.post("/qa", response_model=QAResponse)
async def question_answering(request: QARequest):
"""Responde perguntas sobre um texto usando LLM."""
system_prompt = """Você é um assistente especializado em responder perguntas sobre textos.
Responda de forma precisa e concisa, baseando-se APENAS no contexto fornecido.
Se a resposta não estiver no contexto, diga "Não encontrei essa informação no texto."
Responda em português."""
prompt = f"""CONTEXTO:
{request.context}
PERGUNTA:
{request.question}
RESPOSTA:"""
answer = await call_llm(prompt, system_prompt, max_tokens=300)
return QAResponse(answer=answer.strip())
@app.post("/caption", response_model=CaptionResponse)
async def generate_caption(request: CaptionRequest):
"""Gera uma descrição/legenda para uma imagem."""
prompt = """Descreva esta imagem em detalhes.
Inclua: objetos principais, cores, ações, ambiente/cenário.
Responda em português, em 2-3 frases."""
caption = await call_vision_llm(request.image_base64, prompt)
return CaptionResponse(caption=caption.strip())
# ================================
# Para rodar localmente
# ================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)