# Suguru1846's picture
# Update app.py
# 79b024d verified
import os
import torch
from fastapi import FastAPI, File, UploadFile
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import traceback
import re
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import login
from pydantic import BaseModel, Field
from typing import Optional
# JSON body accepted by POST /generate.
# Documented with comments rather than a docstring: pydantic copies a class
# docstring into the generated schema, which would change the OpenAPI output.
class GenerateRequest(BaseModel):
    # Instruction text sent to the language model.
    prompt: str
    # Token budget for the reply; /generate forwards it as max_new_tokens.
    max_tokens: int = 1000
    # Unsupported input: the schema description tells clients to leave it null.
    # NOTE(review): confirm the /generate handler actually rejects a non-null value.
    image: Optional[str] = Field(
        default=None,
        description="This field should be None. If an image is detected, the request will be rejected.",
    )
# JSON body accepted by POST /analyze_sentiment.
# (Comments instead of a docstring so the OpenAPI schema stays unchanged.)
class SentimentRequest(BaseModel):
    # Text to classify; handed unchanged to the sentiment pipeline.
    text: str
# Authenticate with the Hugging Face Hub using HF_TOKEN from the environment.
# When the variable is missing or empty, nothing is logged in and only a
# warning is printed.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("Warning: No HF_TOKEN found in environment variables")
else:
    login(token=HF_TOKEN)
# Process-wide settings for the ML libraries used below.
# NOTE(review): transformers and huggingface_hub are imported at the top of this
# file, before these assignments run, so any variable those libraries read at
# import time (e.g. HF_HOME / TRANSFORMERS_CACHE) may not take effect for them.
# Confirm, or move this block above those imports.
os.environ.update(
    {
        # Presumably meant to keep bitsandbytes/transformers off Triton and on
        # the PyTorch backend, and to silence the bitsandbytes banner -- verify.
        "TRITON_DISABLE": "1",
        "BNB_DISABLE_TRITON": "1",
        "USE_TORCH": "1",
        "BITSANDBYTES_NOWELCOME": "1",
    }
)
# One writable scratch directory shared by the HF hub, transformers and torch caches.
_HF_CACHE_DIR = "/tmp/hf_cache"
os.makedirs(_HF_CACHE_DIR, exist_ok=True)
os.environ.update(
    {
        "HF_HOME": _HF_CACHE_DIR,
        "TRANSFORMERS_CACHE": _HF_CACHE_DIR,
        "TORCH_HOME": _HF_CACHE_DIR,
    }
)
# The ASGI application; the endpoints below are registered on it.
app = FastAPI()

# Cross-origin policy: any origin, method and header, with credentials allowed.
# NOTE(review): FastAPI's CORS docs say allow_origins cannot be ["*"] when
# allow_credentials=True (Starlette then reflects the caller's Origin). Confirm
# whether credentials are actually needed, or list explicit origins.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Load every model once, at import time, so all requests share the same instances.
print("Loading models and tokenizers...")

# transformers/huggingface_hub are imported at the top of this file, before
# HF_HOME / TRANSFORMERS_CACHE are assigned, and they resolve their default cache
# locations when imported. Pass the intended directory explicitly so downloads
# actually land in the configured cache (None falls back to the library default).
_cache_dir = os.environ.get("TRANSFORMERS_CACHE")

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=_cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=_cache_dir,
    low_cpu_mem_usage=True,
    device_map="auto",
    # NOTE(review): float16 is requested even when no CUDA device exists; some
    # torch CPU builds lack half-precision kernels -- confirm the CPU fallback works.
    torch_dtype=torch.float16,
)

# pipeline() device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1
sentiment_analyzer = pipeline(
    "text-classification",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    # Forwarded to the model's from_pretrained so it uses the same cache.
    model_kwargs={"cache_dir": _cache_dir},
    # Score every label rather than only the best one.
    # NOTE(review): deprecated in newer transformers in favour of top_k=None,
    # which may change the nesting of the returned list.
    return_all_scores=True,
    device=device,
)
print("Models and tokenizers loaded successfully!")
@app.post("/generate")
async def generate_text(
    request: GenerateRequest = None,
    prompt: str = None,
    max_tokens: int = 1000,
    file: Optional[UploadFile] = None
):
    """Generate a reply from the Mistral instruct model.

    The prompt and token budget come from the ``GenerateRequest`` body when one
    was parsed, otherwise from the ``prompt`` / ``max_tokens`` parameters. An
    uploaded ``file`` is accepted but never read.

    Returns:
        ``{"response": str}`` on success; ``{"error": str}`` for a missing prompt
        or a rejected image; ``{"error": str, "traceback": str}`` when
        tokenization or generation raises.
    """
    if request is not None:
        # GenerateRequest documents that `image` must be None and that requests
        # carrying an image are rejected, so enforce that contract here.
        if request.image:
            return {"error": "Image input is not supported; the 'image' field must be null."}
        user_prompt = request.prompt
        tokens = request.max_tokens
    else:
        user_prompt = prompt
        tokens = max_tokens
    if not user_prompt:
        return {"error": "No prompt provided"}
    try:
        formatted_prompt = f"<s>[INST] {user_prompt} [/INST]"
        # The template already spells out the BOS token "<s>"; letting the
        # tokenizer prepend its own as well would feed the model two BOS tokens.
        inputs = tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=False
        ).to(model.device)
        # NOTE(review): generate() is a blocking call inside an async endpoint, so
        # it holds the event loop (and every other request) until it finishes.
        outputs = model.generate(
            **inputs,
            max_new_tokens=tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        # A decoder-only model returns prompt + continuation. Decode only the new
        # tokens instead of trying to strip the echoed prompt back out of the text.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Safety net for any instruction/sentence tags the model emits as plain text.
        clean_response = re.sub(
            r'</?s>|\[/?s\]|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', generated_text
        ).strip()
        return {"response": clean_response}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        # NOTE(review): returning the traceback exposes server internals to clients.
        return {"error": error_msg, "traceback": error_trace}
# Maps the nlptown model's five star-rating labels onto the API's sentiment names.
_SENTIMENT_BY_LABEL = {
    "1 star": "very_negative",
    "2 stars": "negative",
    "3 stars": "neutral",
    "4 stars": "positive",
    "5 stars": "very_positive"
}


@app.post("/analyze_sentiment")
async def analyze_sentiment(request: SentimentRequest):
    """Classify ``request.text`` on a five-level sentiment scale.

    Returns:
        ``{"sentiment": str, "raw_scores": {label: score}}`` on success, or
        ``{"error": str, "traceback": str}`` if classification fails (including a
        label that is missing from ``_SENTIMENT_BY_LABEL``).
    """
    try:
        # Texts longer than the classifier's position-embedding limit otherwise
        # make the model fail; truncate them to fit instead.
        max_length = sentiment_analyzer.model.config.max_position_embeddings
        result = sentiment_analyzer(request.text, truncation=True, max_length=max_length)
        # With return_all_scores=True a single string yields [[{label, score}, ...]].
        scores = {item["label"]: item["score"] for item in result[0]}
        print(f"Raw scores for '{request.text}': {scores}")
        # Ties resolve to the first label in iteration order.
        highest_score_label = max(scores, key=scores.get)
        sentiment = _SENTIMENT_BY_LABEL[highest_score_label]
        print(f"Final sentiment: {sentiment}")
        return {"sentiment": sentiment, "raw_scores": scores}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error analyzing sentiment: {error_msg}")
        print(f"Traceback: {error_trace}")
        # NOTE(review): returning the traceback exposes server internals to clients.
        return {"error": error_msg, "traceback": error_trace}