mod types; use types::{VerifyRequest, VerifyResponse, Evidence}; use std::env; use actix_web::{get, post, web, App, HttpServer, Responder, HttpResponse}; use actix_cors::Cors; use qdrant_client::Qdrant; use qdrant_client::qdrant::QueryPointsBuilder; use ort::session::Session; use ort::value::Tensor; use tokenizers::Tokenizer; use deadpool::managed::{Manager, RecycleResult}; // Bi-Encoder pool manager struct BiEncoderManager; impl Manager for BiEncoderManager { type Type = Session; type Error = ort::Error; async fn create(&self) -> Result { Session::builder()? .commit_from_file("models/bge_small/model.onnx") } async fn recycle(&self, _obj: &mut Self::Type, _: &deadpool::managed::Metrics) -> RecycleResult { Ok(()) } } // Cross-Encoder pool manager struct CrossEncoderManager; impl Manager for CrossEncoderManager { type Type = Session; type Error = ort::Error; async fn create(&self) -> Result { Session::builder()? .commit_from_file("models/pubmedbert/model_quantized.onnx") } async fn recycle(&self, _obj: &mut Self::Type, _: &deadpool::managed::Metrics) -> RecycleResult { Ok(()) } } // Define the Pool types for convenience type BiEncoderPool = deadpool::managed::Pool; type CrossEncoderPool = deadpool::managed::Pool; // This struct holds our shared application state struct AppState { qdrant_client: Qdrant, bi_encoder_pool: BiEncoderPool, bi_encoder_tokenizer: Tokenizer, cross_encoder_pool: CrossEncoderPool, cross_encoder_tokenizer: Tokenizer, } const COLLECTION_NAME: &str = "verity_hybrid_corpus"; // Health check endpoint #[get("/")] async fn health_check() -> impl Responder { HttpResponse::Ok().body("Verity Backend is running on Hugging Face Spaces!") } // Endpoint for verifying a claim #[post("/api/verify")] async fn verify_claim( req_body: web::Json, data: web::Data, ) -> impl Responder { let claim: &str = &req_body.claim; // 250 equates to 60-70 tokens and prevents Premise Length Dilution in the Cross-Encoder. if claim.chars().count() > 250 { return HttpResponse::BadRequest().body("Claim is too long. Please limit your claim to 250 characters."); } println!("============================================================================================================"); println!("Received claim: {}", claim); println!("============================================================================================================\n"); // BGE-Small strictly requires this exact prefix for search queries let bi_encoder_query = format!("Represent this sentence for searching relevant passages: {}", claim); // Embed the claim using the BGE Small model let encoding = data.bi_encoder_tokenizer.encode(bi_encoder_query, true).expect("Failed to encode claim using the Bi-Encoder tokenizer"); // BGE Small is build on a BERT style architecture and requires these three inputs // ONNX expects 64 bit integers for input ids and masks // Ids of the tokens in the sequence let input_ids: Vec = encoding.get_ids().iter().map(|&x| x as i64).collect(); // Tells the model which tokens are padding and which are actual tokens. (1 for actual tokens, 0 for padding) let attention_mask: Vec = encoding.get_attention_mask().iter().map(|&x| x as i64).collect(); // Tells the model which tokens belong to which sequence. (0 for the first sequence, 1 for the second, etc.) let token_type_ids: Vec = encoding.get_type_ids().iter().map(|&x| x as i64).collect(); let batch_size: usize = 1; // This is the number of sequences let seq_len: usize = input_ids.len(); // This is the number of tokens in the sequence let shape = vec![batch_size, seq_len]; // Pass a standard Rust Tuple (Shape, Data) directly to ONNX. // Get the Bi-Encoder from the pool asynchronously let mut bi_encoder = data.bi_encoder_pool.get().await.expect("Failed to get Bi-Encoder from pool"); // Run the Bi-Encoder inference in a blocking thread let embedding: Vec = web::block(move || { let outputs = bi_encoder.run(ort::inputs![ "input_ids" => Tensor::from_array((shape.clone(), input_ids)).unwrap(), "attention_mask" => Tensor::from_array((shape.clone(), attention_mask)).unwrap(), "token_type_ids" => Tensor::from_array((shape.clone(), token_type_ids)).unwrap(), ]).expect("Failed to run the Bi-Encoder model"); let (_shape, tensor_data) = outputs["last_hidden_state"] .try_extract_tensor::() .expect("Failed to extract float tensor from ONNX output"); tensor_data[0..384].to_vec() }) .await .expect("Blocking task for Bi-Encoder failed"); // Extract the dynamic threshold from the request (default to 0.55) let threshold: f32 = req_body.qdrant_threshold.unwrap_or(0.55); // Query Qdrant for top 5 matches let query_request = QueryPointsBuilder::new(COLLECTION_NAME) .query(embedding) .limit(9) .score_threshold(threshold) .with_payload(true); let response_result = data.qdrant_client .query(query_request).await; let response = match response_result { Ok(response) => { response } Err(e) => { println!("Error querying Qdrant: {:?}", e); return HttpResponse::InternalServerError().body("Error querying Qdrant"); } }; // Perform dynamic radius retrieval let radius: f32 = 0.05; let top_result = response.result.first().map(|hit| hit.score).unwrap_or(0.0); let response = response.result.into_iter().filter(|hit| hit.score >= top_result - radius).collect::>(); // Trackers for document-level thresholded weighted sum pooling let mut weighted_support_sum: f32 = 0.0; let mut weighted_refute_sum: f32 = 0.0; let mut highest_neutral_score: f32 = 0.0; let mut evidence_list: Vec = Vec::new(); // Get a cross-encoder model from the pool asynchronously let mut cross_encoder = data.cross_encoder_pool.get().await.expect("Failed to get Cross-Encoder from pool"); for hit in response.iter() { let payload = &hit.payload; // Helper to get fields from Qdrant response let get_string = |key: &str| { payload.get(key).and_then(|v| v.as_str()).map(|s| s.as_str()).unwrap_or("Unknown") }; let title = get_string("title"); let source = get_string("dataset_source"); let abstract_text = get_string("abstract"); let _score = hit.score; println!("----------------------------------------------------------------------------------------------------"); println!("Score: {:.4} | Title: {} [{}]", _score, title, source); println!("----------------------------------------------------------------------------------------------------\n"); // Variables for chunk-level max pooling let mut best_support: f32 = 0.0; let mut best_refute: f32 = 0.0; let mut max_signal: f32 = 0.0; // Clean and split the abstract text into sentences let sentences: Vec = abstract_text .split(". ") .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| format!("{}.", s)) .collect(); let window_size = 2; let mut chunks: Vec = Vec::new(); // If there are less than two sentences, combine them. If there are more, perform windowing if sentences.len() <= window_size { chunks.push(sentences.join(" ")); } else { for window in sentences.windows(window_size) { chunks.push(window.join(" ")); } } let batch_size = chunks.len(); if batch_size == 0 { continue; } // Preapre the inputs for the tokenizer. A list of tuples (Chunk, Claim) let encoding_inputs: Vec<(String, String)> = chunks .iter() .map(|chunk| (chunk.clone(), claim.to_string())) .collect(); // Encode the entire batch at once. The tokenizer will pad them automatically let encodings = data.cross_encoder_tokenizer .encode_batch(encoding_inputs, true) .unwrap(); // Since they are padded, we can get the sequence length from the first encoding let seq_len = encodings[0].get_ids().len(); // Flatten the 2D batch into 1D vectors for ONNX let mut input_ids: Vec = Vec::with_capacity(batch_size * seq_len); let mut attention_mask: Vec = Vec::with_capacity(batch_size * seq_len); let mut token_type_ids: Vec = Vec::with_capacity(batch_size * seq_len); for encoding in encodings { input_ids.extend(encoding.get_ids().iter().map(|&x| x as i64)); attention_mask.extend(encoding.get_attention_mask().iter().map(|&x| x as i64)); token_type_ids.extend(encoding.get_type_ids().iter().map(|&x| x as i64)); } let shape = vec![batch_size as i64, seq_len as i64]; // Offload the batched inference to a blocking thread let (logits_result, returned_cross_encoder) = web::block(move || { let logits_res = { let result = cross_encoder.run(ort::inputs![ "input_ids" => Tensor::from_array((shape.clone(), input_ids)).unwrap(), "attention_mask" => Tensor::from_array((shape.clone(), attention_mask)).unwrap(), "token_type_ids" => Tensor::from_array((shape.clone(), token_type_ids)).unwrap(), ]); match result { Ok(outputs) => { match outputs["logits"].try_extract_tensor::() { Ok((_shape, logits_data)) => Ok(logits_data.to_vec()), Err(e) => Err(ort::Error::from(e)) } } Err(e) => Err(e) } }; (logits_res, cross_encoder) }) .await .expect("Cross-Encoder blocking failed"); // Get cross-encoder back for next iteration cross_encoder = returned_cross_encoder; let flat_logits = match logits_result { Ok(logits) => logits, Err(e) => { println!("Warning: Skipping document. ONNX batched inference failed. Error: {:?}", e); continue; } }; // Process the flattened logits array. // Every 3 items in the array correspond to [Refute, Support, Neutral] for one chunk.\ for logits_chunk in flat_logits.chunks(3) { // Convert logits to SoftMax probabilities let max_logit: f32 = logits_chunk[0].max(logits_chunk[1].max(logits_chunk[2])); let exp_logits: Vec = logits_chunk.iter().map(|&x| (x - max_logit).exp()).collect(); let sum_logits: f32 = exp_logits.iter().sum(); let softmax_probs: Vec = exp_logits.iter().map(|&x| x / sum_logits).collect(); // Extract probabilities for each class let refute_prob = softmax_probs[0]; let support_prob = softmax_probs[1]; let chunk_signal = refute_prob.max(support_prob); // Keep only the highest signal chunk for chunk-level max pooling if chunk_signal > max_signal { max_signal = chunk_signal; best_support = support_prob; best_refute = refute_prob; } } let stance; let confidence; // If confidence is lower than 0.50 for the document we assume stance as neutral since the cross-encoder uncertainty if best_support > best_refute && best_support > 0.50 { stance = "SUPPORT".to_string(); confidence = best_support; } else if best_refute > best_support && best_refute > 0.50 { stance = "REFUTE".to_string(); confidence = best_refute; } else { stance = "NEUTRAL".to_string(); confidence = 1.0 - best_support - best_refute; } // Perform document-level thresholded weighted sum pooling let weighted_confidence = confidence * _score; // Only include documents with confidence > 0.80 for the weighted thresholded sum pooling if confidence > 0.80 { if stance == "SUPPORT" { weighted_support_sum += weighted_confidence; } else if stance == "REFUTE" { weighted_refute_sum += weighted_confidence; } } // Track the highest neutral score just in case all documents are filtered out if stance == "NEUTRAL" && confidence > highest_neutral_score { highest_neutral_score = confidence; } // Find the byte index of the 200th character, or the end of the string let end_index = abstract_text .char_indices() .map(|(i, _)| i) .nth(200) .unwrap_or(abstract_text.len()); let snippet = if abstract_text.len() > end_index { format!("{}...", &abstract_text[..end_index]) } else { abstract_text.to_string() }; evidence_list.push(types::Evidence { title: title.to_string(), source: source.to_string(), snippet, stance, confidence, }); } // Map output to match the True/False expectation for your benchmark script let mut final_verdict = "NEUTRAL".to_string(); let aggregate_confidence; // The stance with the most accumulated weighted evidence wins if weighted_support_sum > weighted_refute_sum && weighted_support_sum > 0.0 { final_verdict = "TRUE".to_string(); // Normalize the confidence let total_sum = weighted_support_sum + weighted_refute_sum; aggregate_confidence = weighted_support_sum / total_sum; } else if weighted_refute_sum > weighted_support_sum && weighted_refute_sum > 0.0 { final_verdict = "FALSE".to_string(); let total_sum = weighted_support_sum + weighted_refute_sum; aggregate_confidence = weighted_refute_sum / total_sum; } else { // Fallback if no documents passed the 0.80 threshold aggregate_confidence = highest_neutral_score; } let response = VerifyResponse { final_verdict, aggregate_confidence, evidence: evidence_list, }; HttpResponse::Ok().json(response) } #[actix_web::main] async fn main() -> std::io::Result<()> { // Initialize ONNX runtime engine ort::init().with_name("verity_inference_engine").commit(); // Initialize Qdrant Client (Connecting over the Docker network) let qdrant_url = env::var("QDRANT_URL").ok(); let qdrant_api_key = env::var("QDRANT_API_KEY").ok(); let client_result = if let (Some(url), Some(key)) = (qdrant_url, qdrant_api_key) { println!("\nBackend connecting to Qdrant at {}...", url); Qdrant::from_url(&url) .api_key(key) .build() } else { let local_url = "http://qdrant:6334"; println!("\nBackend connecting to local Qdrant at {}...", local_url); Qdrant::from_url(local_url) .build() }; let client = match client_result { Ok(client) => { println!("Backend connected to Qdrant successfully."); client } Err(e) => { println!("Failed to connect to Qdrant: {:?}", e); return Err(std::io::Error::new(std::io::ErrorKind::Other, "Backend failed to connect to Qdrant")); } }; // Configure the model pools let pool_size = 2; let bi_encoder_pool = deadpool::managed::Pool::builder(BiEncoderManager) .max_size(pool_size) .build() .expect("Failed to create Bi-Encoder pool"); let cross_encoder_pool = deadpool::managed::Pool::builder(CrossEncoderManager) .max_size(pool_size) .build() .expect("Failed to create Cross-Encoder pool"); // Initialize Cross-Encoder Tokenizer with padding enabled let mut cross_encoder_tokenizer = Tokenizer::from_file("models/pubmedbert/tokenizer.json").unwrap(); // Enable automatic padding to the longest sequence in the batch cross_encoder_tokenizer.with_padding(Some(tokenizers::PaddingParams { strategy: tokenizers::PaddingStrategy::BatchLongest, ..Default::default() })); // Create app state let app_state = web::Data::new(AppState { qdrant_client: client, bi_encoder_pool, bi_encoder_tokenizer: Tokenizer::from_file("models/bge_small/tokenizer.json").unwrap(), cross_encoder_pool, cross_encoder_tokenizer, // Use configured tokenizer with padding enabled }); // Dynamically get port let port = std::env::var("PORT").unwrap_or_else(|_| "7860".to_string()); let bind_address = format!("0.0.0.0:{}", port); println!("Started Verity Rust API on {}...", bind_address); // Start the HTTP Server HttpServer::new(move || { // Configure CORS let cors = Cors::default() .allowed_origin("https://verity-frontend-enqx.vercel.app") .allowed_origin("http://localhost:5173") .allowed_methods(vec!["GET", "POST"]) .allow_any_header() .max_age(3600); App::new() .wrap(cors) .app_data(app_state.clone()) .service(health_check) .service(verify_claim) }) .bind(&bind_address)? .run() .await }