# cgpbot / app.py
# Commit b8b0952 (chikamov1): "Integrate all chatbot logic into app.py and update Dockerfile"
from fastapi import FastAPI
from pydantic import BaseModel
import json
import os
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from dotenv import load_dotenv
# --- Load environment variables ---
# Reads a local .env file (if present) into the process environment,
# so GEMINI_API_KEY can be supplied without exporting it manually.
load_dotenv()
# --- Configuration ---
# Hugging Face Hub IDs for your trained MarianMT models
HF_EN_FR_REPO_ID = "cgpcorpbot/cgp_model_en-fr"
HF_FR_EN_REPO_ID = "cgpcorpbot/cgp_model_fr-en"
# Gemini API configuration
# Empty string when the key is unset; the API-calling helpers below
# check for this and return an error string instead of calling out.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
# Correct Gemini API URL for generateContent
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# --- Load MarianMT models and tokenizers using pipeline ---
# Loaded once at import time so every request reuses the same models.
# Any failure here aborts startup: a translation API with no models is useless.
try:
    # Explicitly pin to CPU — no GPU is assumed in the deployment container.
    device = torch.device("cpu")

    print("Loading EN->FR model...")
    tokenizer_en_fr = AutoTokenizer.from_pretrained(HF_EN_FR_REPO_ID)
    model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(HF_EN_FR_REPO_ID).to(device)
    # pipeline() wraps tokenize -> generate -> decode into a single call.
    translator_en_fr = pipeline("translation", model=model_en_fr, tokenizer=tokenizer_en_fr, device=device)
    print(f"✅ MarianMT EN-FR Model loaded from Hugging Face Hub: {HF_EN_FR_REPO_ID} and moved to {device}")

    print("Loading FR->EN model...")
    tokenizer_fr_en = AutoTokenizer.from_pretrained(HF_FR_EN_REPO_ID)
    model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(HF_FR_EN_REPO_ID).to(device)
    translator_fr_en = pipeline("translation", model=model_fr_en, tokenizer=tokenizer_fr_en, device=device)
    print(f"✅ MarianMT FR-EN Model loaded from Hugging Face Hub: {HF_FR_EN_REPO_ID} and moved to {device}")
except Exception as e:
    print(f"❌ Failed to load MarianMT models from Hugging Face Hub: {e}")
    # Chain the original exception so the startup traceback shows the root cause.
    raise RuntimeError(f"Failed to load translation models: {e}") from e
# --- Language Detection (Simplified) ---
# --- Language Detection (Simplified) ---
def detect_language(text: str) -> str:
    """Crude keyword-based French/English detector.

    Counts how many distinct French stopwords occur as whole words in
    *text* (case-insensitive). Returns ``"french"`` when more than two
    are found, else ``"english"``. Only suited to short chat messages;
    punctuation glued to a word (e.g. "où?") will not match.
    """
    french_keywords = {
        "le", "la", "les", "un", "une", "des", "est", "sont", "je", "tu",
        "il", "elle", "nous", "vous", "ils", "elles", "pas", "de", "du",
        "et", "à", "en", "que", "qui", "quoi", "comment", "où", "quand",
    }
    # Split once and intersect with the keyword set; the original re-split
    # the text for every keyword (O(words * keywords) and per-keyword list
    # scans). Intersection still counts each keyword at most once, which
    # preserves the original semantics.
    matched = french_keywords.intersection(text.lower().split())
    return "french" if len(matched) > 2 else "english"
# --- MarianMT Translation with Gemini Fallback ---
# --- MarianMT Translation with Gemini Fallback ---
async def translate_text_with_fallback(text: str, direction: str) -> str:
    """
    Translate *text* with the local MarianMT pipelines, falling back to
    Gemini if the local model raises.

    `direction` can be "en-fr" or "fr-en"; any other value yields the
    sentinel string "Invalid translation direction.". Empty/whitespace
    input returns "".

    BUG FIX: this function is now ``async``. The original was a plain
    function that *returned the un-awaited coroutine* from
    ``gemini_translate_for_translation`` on the fallback path, while both
    callers ``await`` this function — so the success path (a plain str)
    would have raised ``TypeError: object str can't be used in 'await'``.
    Declaring it async and awaiting the fallback fixes both paths.
    """
    if not text.strip():
        return ""
    try:
        if direction == "en-fr":
            result = translator_en_fr(text, max_length=128)
        elif direction == "fr-en":
            result = translator_fr_en(text, max_length=128)
        else:
            return "Invalid translation direction."
        return result[0]['translation_text'].strip()
    except Exception as e:
        print(f"⚠️ Local MarianMT model failed for {direction}, falling back to Gemini for translation: {e}")
        # Fallback to Gemini for translation if MarianMT fails.
        return await gemini_translate_for_translation(text, direction)
async def gemini_translate_for_translation(text: str, direction: str) -> str:
    """
    Use the Gemini API to translate *text* when the local MarianMT model
    fails. This is a translation-only fallback, not the general chatbot
    response path.

    Returns the translated text, or one of the sentinel error strings
    ("API key missing for translation." / "Gemini API translation error.")
    that callers compare against.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for translation fallback.")
        return "API key missing for translation."
    target_lang = "French" if direction == "en-fr" else "English"
    # Prompt Gemini to perform the translation directly.
    prompt = f"Translate the following text to {target_lang}: \"{text}\""
    payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"
    try:
        # NOTE(review): requests.post is a blocking call inside an async
        # function; it stalls the event loop for the request's duration.
        # A timeout is set so a hung upstream cannot block forever
        # (requests has NO default timeout).
        response = requests.post(api_url_with_key, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "Translation failed via Gemini.")
    except Exception as e:
        print(f"Gemini API translation fallback error: {e}")
        return "Gemini API translation error."
# --- Main Gemini API Call for Conversational Response ---
# --- Main Gemini API Call for Conversational Response ---
async def call_gemini_api_for_response(prompt: str) -> str:
    """
    Call the Gemini API for a conversational answer, always in English.

    This is the primary chatbot-response function; language handling is
    done by the caller. Returns the model's text, or one of the sentinel
    error strings ("API key missing." / "Gemini API error for main
    response." / "No response from Gemini.") that callers compare against.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for main response.")
        return "API key missing."
    # Gemini is always prompted in English for consistent behavior;
    # translation to/from the user's language happens around this call.
    gemini_prompt = f"Answer the following question in English: {prompt}"
    payload = {"contents": [{"role": "user", "parts": [{"text": gemini_prompt}]}]}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"
    try:
        # NOTE(review): blocking requests call inside async — stalls the
        # event loop while waiting. Timeout added so a hung upstream cannot
        # block forever (requests has NO default timeout).
        response = requests.post(api_url_with_key, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "No response from Gemini.")
    except Exception as e:
        print(f"Gemini API error for main response: {e}")
        return "Gemini API error for main response."
# --- Main Chatbot Response Logic ---
# --- Main Chatbot Response Logic ---
async def get_multilingual_chatbot_response(user_input: str) -> str:
    """
    Produce a chatbot reply in the user's language.

    Pipeline: detect language -> (if French) translate query to English ->
    get an English answer from Gemini -> (if French) translate the answer
    back. Every fallback degrades gracefully to whatever text is available.
    """
    # Sentinel strings returned by the helpers on failure; we compare
    # against them to decide whether a step actually succeeded.
    translation_errors = ("API key missing for translation.", "Gemini API translation error.")
    gemini_errors = ("API key missing for main response.", "Gemini API error for main response.", "No response from Gemini.")

    detected_lang = detect_language(user_input)
    print(f"Detected language: {detected_lang.upper()}")

    english_query = user_input
    if detected_lang == "french":
        print("Translating French query to English...")
        english_query = await translate_text_with_fallback(user_input, "fr-en")
        print(f"Translated query (EN): {english_query}")
        if not english_query.strip() or english_query in translation_errors:
            # Translation failed or came back empty — let Gemini see the
            # original French rather than nothing.
            english_query = user_input
            print("French to English translation resulted in empty string or error, using original input for Gemini.")

    # Get the conversational response from Gemini (always in English).
    gemini_response_en = await call_gemini_api_for_response(english_query)
    print(f"Gemini response (EN): {gemini_response_en}")

    final_response = gemini_response_en
    # Only translate back if the input was French AND Gemini answered.
    if detected_lang == "french" and gemini_response_en not in gemini_errors:
        print("Translating English response back to French...")
        translated_back_fr = await translate_text_with_fallback(gemini_response_en, "en-fr")
        if translated_back_fr.strip() and translated_back_fr not in translation_errors:
            final_response = translated_back_fr
        else:
            print("English to French translation resulted in empty string or error, using English Gemini response.")
            final_response = gemini_response_en  # Fall back to English.
    return final_response
# --- FastAPI Application ---
app = FastAPI()


# Request body schema for the /chat endpoint.
class ChatRequest(BaseModel):
    # The raw user message; language is auto-detected server-side.
    user_input: str
@app.get("/")
async def root():
    """Health-check endpoint confirming the API is up."""
    return {"message": "Multilingual Chatbot API is running. Use /chat endpoint."}
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Multilingual chatbot endpoint.

    Receives user input, detects the language, translates to English if
    needed, gets a Gemini response, and translates it back. Returns
    ``{"response": <text>}``; empty input yields a prompt to provide some.
    """
    user_input = request.user_input
    if not user_input:
        return {"response": "Please provide some input."}
    response = await get_multilingual_chatbot_response(user_input)
    return {"response": response}