"""FastAPI service exposing Mistral-7B-Instruct text generation and
star-rating sentiment analysis (nlptown multilingual BERT)."""

import os
import re
import traceback
from typing import Optional

# Environment configuration must happen BEFORE torch / transformers /
# huggingface_hub are imported: cache locations such as HF_HOME and
# TRANSFORMERS_CACHE are read when those libraries are imported, so setting
# them afterwards would leave the default cache directory in use.
os.environ["TRITON_DISABLE"] = "1"
os.environ["BNB_DISABLE_TRITON"] = "1"
os.environ["USE_TORCH"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["TORCH_HOME"] = "/tmp/hf_cache"

import torch  # noqa: E402  (must follow the environment setup above)
from fastapi import FastAPI, File, UploadFile  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from huggingface_hub import login  # noqa: E402
from pydantic import BaseModel, Field  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline  # noqa: E402


class GenerateRequest(BaseModel):
    """JSON body accepted by ``/generate``."""

    prompt: str
    max_tokens: int = 1000
    image: Optional[str] = Field(
        None,
        description="This field should be None. If an image is detected, the request will be rejected.",
    )


class SentimentRequest(BaseModel):
    """JSON body accepted by ``/analyze_sentiment``."""

    text: str


# Use environment variable for token
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("Warning: No HF_TOKEN found in environment variables")

app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive for a public endpoint; consider listing explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load models once at startup
print("Loading models and tokenizers...")
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float16,
)

device = 0 if torch.cuda.is_available() else -1  # pipeline convention: -1 = CPU
# NOTE(review): newer transformers releases deprecate return_all_scores in
# favour of top_k=None, whose output nesting differs; the parsing in
# analyze_sentiment relies on the current nesting — confirm before switching.
sentiment_analyzer = pipeline(
    "text-classification",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    return_all_scores=True,
    device=device,  # Use GPU if available, otherwise CPU
)
print("Models and tokenizers loaded successfully!")

# Instruction / sentence tags that may still appear as plain text in the
# generated continuation. (The previous pattern began with an empty
# alternative — apparently a stripped ``<s>|</s>`` — restored here as ``</?s>``.)
_STRAY_TAG_PATTERN = re.compile(
    r"</?s>|\[/?s\]|\[/?INST\]|\[/?INSR\]|\{/?INSST\}"
)

# Star-rating labels emitted by the nlptown model -> sentiment names returned to clients.
SENTIMENT_LABELS = {
    "1 star": "very_negative",
    "2 stars": "negative",
    "3 stars": "neutral",
    "4 stars": "positive",
    "5 stars": "very_positive",
}


def _error_payload(action: str, exc: Exception) -> dict:
    """Log the exception being handled and build the JSON error response.

    Must be called from inside an ``except`` block so that
    ``traceback.format_exc()`` sees the active exception.
    """
    error_msg = str(exc)
    error_trace = traceback.format_exc()
    print(f"Error {action}: {error_msg}")
    print(f"Traceback: {error_trace}")
    # NOTE(review): echoing the traceback exposes server internals to any
    # caller; kept only so the response shape stays backward-compatible.
    return {"error": error_msg, "traceback": error_trace}


# NOTE(review): both endpoints are ``async def`` but call the models
# synchronously, so each inference holds the event loop until it finishes.
# Consider offloading (e.g. starlette's run_in_threadpool) with serialized
# model access.
@app.post("/generate")
async def generate_text(
    request: GenerateRequest = None,
    prompt: str = None,
    max_tokens: int = 1000,
    file: Optional[UploadFile] = None,
):
    """Generate a completion for a prompt with the Mistral instruct model.

    The prompt and token budget come from the ``GenerateRequest`` body when one
    is supplied, otherwise from the ``prompt`` / ``max_tokens`` parameters.
    Any uploaded ``file`` is ignored.

    Returns ``{"response": str}`` on success, or a dict with an ``"error"`` key
    (plus ``"traceback"`` when generation raised).
    """
    del file  # Uploads are deliberately discarded; this endpoint is text-only.

    if request is not None:
        if request.image:
            # Enforce the contract documented on GenerateRequest.image.
            return {"error": "Image input is not supported; the 'image' field must be None"}
        user_prompt = request.prompt
        tokens = request.max_tokens
    else:
        user_prompt = prompt
        tokens = max_tokens

    if not user_prompt:
        return {"error": "No prompt provided"}

    try:
        formatted_prompt = f"[INST] {user_prompt} [/INST]"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,  # tokenizer defines no pad token
        )
        # generate() returns prompt + continuation for a decoder-only model;
        # decoding only the new tokens means the echoed prompt never has to be
        # stripped back out of the text with string heuristics.
        prompt_length = inputs["input_ids"].shape[-1]
        response_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        clean_response = _STRAY_TAG_PATTERN.sub("", response_text).strip()
        return {"response": clean_response}
    except Exception as e:
        return _error_payload("generating text", e)


@app.post("/analyze_sentiment")
async def analyze_sentiment(request: SentimentRequest):
    """Classify ``request.text`` into one of five sentiment buckets.

    Returns ``{"sentiment": str, "raw_scores": {label: score}}`` on success, or
    ``{"error": str, "traceback": str}`` if the pipeline raised.
    """
    try:
        result = sentiment_analyzer(request.text)
        # Expected shape for one input string: [[{"label": ..., "score": ...}, ...]].
        scores = {entry["label"]: entry["score"] for entry in result[0]}

        # Add debug logs
        print(f"Raw scores for '{request.text}': {scores}")

        # Pick the highest-scoring star label and map it to a sentiment name.
        highest_score_label = max(scores, key=scores.get)
        sentiment = SENTIMENT_LABELS[highest_score_label]
        print(f"Final sentiment: {sentiment}")

        return {"sentiment": sentiment, "raw_scores": scores}
    except Exception as e:
        return _error_payload("analyzing sentiment", e)